Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| """ | |
| A script to benchmark builtin models. | |
| Note: this script has an extra dependency of psutil. | |
| """ | |
| import itertools | |
| import logging | |
| import psutil | |
| import torch | |
| import tqdm | |
| from fvcore.common.timer import Timer | |
| from torch.nn.parallel import DistributedDataParallel | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.config import get_cfg | |
| from detectron2.data import ( | |
| DatasetFromList, | |
| build_detection_test_loader, | |
| build_detection_train_loader, | |
| ) | |
| from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch | |
| from detectron2.modeling import build_model | |
| from detectron2.solver import build_optimizer | |
| from detectron2.utils import comm | |
| from detectron2.utils.events import CommonMetricPrinter | |
| from detectron2.utils.logger import setup_logger | |
# Module-wide logger named "detectron2". setup() calls setup_logger(), which
# presumably configures handlers for this logger name — verify against
# detectron2.utils.logger.setup_logger.
logger = logging.getLogger("detectron2")
def setup(args):
    """Build the frozen benchmark config from ``args`` and set up logging.

    The YAML at ``args.config_file`` is merged first. Next, SOLVER.BASE_LR is
    pinned to 0.001. Finally, the command-line overrides in ``args.opts`` are
    applied, so an explicit SOLVER.BASE_LR in ``args.opts`` still wins.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # A small LR avoids NaNs; its exact value does not matter for benchmarking.
    config.SOLVER.BASE_LR = 0.001
    config.merge_from_list(args.opts)
    config.freeze()
    rank = comm.get_rank()
    setup_logger(distributed_rank=rank)
    return config
def _time_data_iterations(itr, max_iter, ims_per_batch):
    """Pull ``max_iter`` batches from ``itr`` and log the elapsed wall time."""
    timer = Timer()
    for _ in tqdm.trange(max_iter):
        next(itr)
    logger.info(
        "{} iters ({} images) in {} seconds.".format(
            max_iter, max_iter * ims_per_batch, timer.seconds()
        )
    )


def _log_ram_usage():
    """Log system-wide host RAM usage as reported by ``psutil.virtual_memory``."""
    mem = psutil.virtual_memory()
    logger.info(
        "RAM Usage: {:.2f}/{:.2f} GB".format(
            (mem.total - mem.available) / 1024 ** 3, mem.total / 1024 ** 3
        )
    )


def benchmark_data(args):
    """Benchmark the throughput of the detection train data loader.

    The function does the following, in order:
      1. Logs the loader construction time.
      2. Logs the time until the first batch arrives ("startup time").
      3. Finishes a 10-batch warmup.
      4. Times one round of 1000 batches.
      5. Logs host RAM usage.
      6. Times 10 further rounds of 1000 batches each.

    Args:
        args: parsed command-line arguments, forwarded to ``setup``.
    """
    cfg = setup(args)
    max_iter = 1000
    num_extra_rounds = 10
    ims_per_batch = cfg.SOLVER.IMS_PER_BATCH

    timer = Timer()
    dataloader = build_detection_train_loader(cfg)
    logger.info("Initialize loader using {} seconds.".format(timer.seconds()))

    timer.reset()
    itr = iter(dataloader)
    # Time to the first batch covers iterator creation and pipeline start-up.
    next(itr)
    startup_time = timer.seconds()
    logger.info("Startup time: {} seconds".format(startup_time))
    for _ in range(9):  # finish warmup (10 batches in total)
        next(itr)

    _time_data_iterations(itr, max_iter, ims_per_batch)
    _log_ram_usage()

    # test for a few more rounds
    for _ in range(num_extra_rounds):
        _time_data_iterations(itr, max_iter, ims_per_batch)
def benchmark_train(args):
    """Benchmark ``SimpleTrainer`` iterations on a fixed set of cached batches.

    The function performs these steps:
      - Builds the model and wraps it in ``DistributedDataParallel`` when the
        world size is greater than 1.
      - Builds the optimizer and loads ``cfg.MODEL.WEIGHTS`` through a
        ``DetectionCheckpointer`` that holds both the model and the optimizer.
      - Caches the first 100 batches of the train loader, built with
        ``NUM_WORKERS = 0``.
      - Cycles over those cached batches forever as the trainer's data source,
        so data loading cost is largely excluded from the measurement.
      - Calls ``trainer.train(1, 400)``.

    Args:
        args: parsed command-line arguments, forwarded to ``setup``.
    """
    cfg = setup(args)
    net = build_model(cfg)
    logger.info("Model:\n{}".format(net))
    if comm.get_world_size() > 1:
        net = DistributedDataParallel(
            net, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optim = build_optimizer(cfg, net)
    DetectionCheckpointer(net, optimizer=optim).load(cfg.MODEL.WEIGHTS)

    # Load the sample batches in the main process; only a small prefix is needed.
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    loader = build_detection_train_loader(cfg)
    cached_batches = list(itertools.islice(loader, 100))

    def _infinite_batches():
        dataset = DatasetFromList(cached_batches, copy=False)
        while True:
            yield from dataset

    total_iters = 400
    trainer = SimpleTrainer(net, _infinite_batches(), optim)
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.PeriodicWriter([CommonMetricPrinter(total_iters)]),
        ]
    )
    trainer.train(1, total_iters)
@torch.no_grad()
def benchmark_eval(args):
    """Benchmark single-process inference on cached inputs from the first test set.

    The model is put in ``eval()`` mode and loaded from ``cfg.MODEL.WEIGHTS``.
    The first 100 inputs of ``cfg.DATASETS.TEST[0]`` are cached, loaded with
    ``NUM_WORKERS = 0``. After 5 warmup calls, the function times 400 forward
    passes while cycling over the cached inputs.

    Autograd is disabled for the whole function. An inference benchmark must
    not pay for building backward graphs.

    Args:
        args: parsed command-line arguments, forwarded to ``setup``.

    Raises:
        RuntimeError: if the test loader yields no inputs.
    """
    cfg = setup(args)
    model = build_model(cfg)
    model.eval()
    logger.info("Model:\n{}".format(model))
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    dummy_data = list(itertools.islice(data_loader, 100))
    if not dummy_data:
        # Otherwise the infinite generator below would spin without yielding.
        raise RuntimeError(
            "Test dataset {} produced no inputs to benchmark.".format(cfg.DATASETS.TEST[0])
        )
    # Build the wrapper once instead of on every pass over the cached inputs.
    dataset = DatasetFromList(dummy_data, copy=False)

    def f():
        while True:
            yield from dataset

    def _sync():
        # CUDA kernels run asynchronously; wait for them so that the wall-clock
        # timer measures finished work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(5):  # warmup
        model(dummy_data[0])

    max_iter = 400
    _sync()
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as pbar:
        for d in itertools.islice(f(), max_iter):
            model(d)
            pbar.update()
    _sync()
    logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--task", choices=["train", "eval", "data"], required=True)
    args = parser.parse_args()
    # Validate with parser.error rather than assert. Asserts are stripped
    # under `python -O` and give the user no usable message.
    if args.eval_only:
        parser.error("--eval-only is not supported by this benchmark script")

    # Note for "train": training speed may not be representative. The training
    # cost of an R-CNN model varies with the content of the data and the
    # quality of the model.
    tasks = {
        "data": benchmark_data,
        "train": benchmark_train,
        "eval": benchmark_eval,
    }
    f = tasks[args.task]
    if args.task == "eval" and (args.num_gpus != 1 or args.num_machines != 1):
        # only benchmark single-GPU inference.
        parser.error("--task eval requires --num-gpus 1 and --num-machines 1")

    launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,))