# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat
import argparse
from pprint import pprint

import torch
from tqdm import tqdm

from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import (ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR,
                                   change_dataset, get_config)
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,
                                 count_parameters)
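
# Evaluation entry point: builds a model from its config, runs it over one or
# more test datasets with flip augmentation, and reports running-average depth
# metrics (optionally dumping colorized depth/prediction images to disk).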

def infer(model, images, **kwargs):
    """Inference with flip augmentation: average the prediction for the batch
    with the un-flipped prediction for its horizontal mirror."""
    # images.shape = N, C, H, W

    def get_depth_from_prediction(pred):
        # Models may return a raw tensor, a list/tuple of multi-scale outputs,
        # or a dict; normalize to a single depth tensor.
        if isinstance(pred, torch.Tensor):
            pass  # already a depth tensor
        elif isinstance(pred, (list, tuple)):
            pred = pred[-1]
        elif isinstance(pred, dict):
            pred = pred['metric_depth'] if 'metric_depth' in pred else pred['out']
        else:
            raise NotImplementedError(f"Unknown output type {type(pred)}")
        return pred

    pred1 = model(images, **kwargs)
    pred1 = get_depth_from_prediction(pred1)

    pred2 = model(torch.flip(images, [3]), **kwargs)  # flip along width (dim 3)
    pred2 = get_depth_from_prediction(pred2)
    pred2 = torch.flip(pred2, [3])  # un-flip to align with pred1

    mean_pred = 0.5 * (pred1 + pred2)
    return mean_pred
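
# Usage sketch (shapes are illustrative): for a batch `images` of shape
# (N, 3, H, W), `infer(model, images)` returns the depth prediction averaged
# with its horizontal-flip counterpart, at the model's output resolution.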

def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    model.eval()
    metrics = RunningAverageDict()
    for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        if 'has_valid_depth' in sample and not sample['has_valid_depth']:
            continue
        image, depth = sample['image'], sample['depth']
        image, depth = image.cuda(), depth.cuda()
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)  # -> 1, 1, H, W
        # This magic number (focal length) is only used for evaluating the BTS model.
        focal = sample.get('focal', torch.Tensor([715.0873]).cuda())
        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)

        # Optionally save image, ground-truth depth, and prediction for visualization.
        if "save_images" in config and config.save_images:
            import os

            import torchvision.transforms as transforms
            from PIL import Image

            from zoedepth.utils.misc import colorize

            os.makedirs(config.save_images, exist_ok=True)
            d = colorize(depth.squeeze().cpu().numpy(), 0, 10)
            p = colorize(pred.squeeze().cpu().numpy(), 0, 10)
            im = transforms.ToPILImage()(image.squeeze().cpu())
            im.save(os.path.join(config.save_images, f"{i}_img.png"))
            Image.fromarray(d).save(os.path.join(config.save_images, f"{i}_depth.png"))
            Image.fromarray(p).save(os.path.join(config.save_images, f"{i}_pred.png"))

        metrics.update(compute_metrics(depth, pred, config=config))

    if round_vals:
        def r(m): return round(m, round_precision)
    else:
        def r(m): return m
    metrics = {k: r(v) for k, v in metrics.get_value().items()}
    return metrics
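
# Note: the returned dict's keys come from zoedepth.utils.misc.compute_metrics;
# in the reference implementation these are the standard depth metrics
# (e.g. abs_rel, rmse, a1/a2/a3), but the exact set depends on the config.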

def main(config):
    model = build_model(config)
    test_loader = DepthDataLoader(config, 'online_eval').data
    model = model.cuda()
    metrics = evaluate(model, test_loader, config)
    print(f"{colors.fg.green}")
    print(metrics)
    print(f"{colors.reset}")
    metrics['#params'] = f"{round(count_parameters(model, include_all=True) / 1e6, 2)}M"
    return metrics

def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):
    # Use the default pretrained resource from the config unless one is given.
    overwrite = {**kwargs, "pretrained_resource": pretrained_resource} if pretrained_resource else kwargs
    config = get_config(model_name, "eval", dataset, **overwrite)
    # config = change_dataset(config, dataset)  # change the dataset
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    metrics = main(config)
    return metrics

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str, required=False, default=None,
                        help="Pretrained resource to use for fetching weights. If not set, the "
                             "default resource from the model config is used. Refer to "
                             "zoedepth.models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False, default='nyu',
                        help="Dataset to evaluate on")
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    # Expand dataset shorthands into the list of datasets to evaluate.
    if "ALL_INDOOR" in args.dataset:
        datasets = ALL_INDOOR
    elif "ALL_OUTDOOR" in args.dataset:
        datasets = ALL_OUTDOOR
    elif "ALL" in args.dataset:
        datasets = ALL_EVAL_DATASETS
    elif "," in args.dataset:
        datasets = args.dataset.split(",")
    else:
        datasets = [args.dataset]

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)
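
# Example invocations (model names and checkpoint paths are illustrative):
#   python evaluate.py -m zoedepth -d nyu
#   python evaluate.py -m zoedepth_nk -p "local::/path/to/checkpoint.pt" -d ALL_INDOOR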