Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from typing import Any, List, Optional, Tuple | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| from dinov2.models import build_model_from_cfg | |
| from dinov2.utils.config import setup | |
| import dinov2.utils.utils as dinov2_utils | |
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Create the argument parser carrying the model-setup options.

    Args:
        description: Text handed to ``argparse.ArgumentParser`` as its description.
        parents: Parsers whose arguments are inherited; ``None`` (or an empty
            list) means no parent parsers.
        add_help: Whether the parser gets the automatic ``-h/--help`` option.

    Returns:
        A parser that defines ``--config-file``, ``--pretrained-weights``,
        ``--output-dir`` and ``--opts``.
    """
    inherited = parents if parents else []
    parser = argparse.ArgumentParser(description=description, parents=inherited, add_help=add_help)

    # Registered in this order so the generated help text lists them the same way.
    option_table = (
        ("--config-file", {"type": str, "help": "Model configuration file"}),
        ("--pretrained-weights", {"type": str, "help": "Pretrained model weights"}),
        ("--output-dir", {"default": "", "type": str, "help": "Output directory to write results and logs"}),
        ("--opts", {"help": "Extra configuration options", "default": [], "nargs": "+"}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser
def get_autocast_dtype(config):
    """Translate the teacher backbone's configured ``param_dtype`` into a torch dtype.

    Args:
        config: Object exposing
            ``compute_precision.teacher.backbone.mixed_precision.param_dtype``.

    Returns:
        ``torch.half`` for ``"fp16"``, ``torch.bfloat16`` for ``"bf16"``, and
        ``torch.float`` for any other value.
    """
    precision_name = config.compute_precision.teacher.backbone.mixed_precision.param_dtype

    # Checked in order; the first matching name decides the dtype.
    for known_name, dtype in (("fp16", torch.half), ("bf16", torch.bfloat16)):
        if precision_name == known_name:
            return dtype

    # Anything unrecognised falls back to full precision.
    return torch.float
def build_model_for_eval(config, pretrained_weights):
    """Build the teacher-only model, load its weights and prepare it for inference.

    Args:
        config: Configuration forwarded to ``build_model_from_cfg``.
        pretrained_weights: Weights reference forwarded to
            ``dinov2_utils.load_pretrained_weights`` (presumably a checkpoint
            path -- confirm against that helper).

    Returns:
        The model after ``eval()`` and ``cuda()`` have been called on it.
    """
    # Only the first element of the builder's result is used here.
    eval_model, _unused = build_model_from_cfg(config, only_teacher=True)

    # "teacher" is passed through as the third argument of the loader
    # (looks like the checkpoint key to read -- TODO confirm).
    dinov2_utils.load_pretrained_weights(eval_model, pretrained_weights, "teacher")

    eval_model.eval()
    eval_model.cuda()
    return eval_model
def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
    """Run setup from parsed arguments and return the eval model with its autocast dtype.

    Args:
        args: Parsed arguments; passed whole to ``setup`` and read here for
            ``args.pretrained_weights``.

    Returns:
        A ``(model, autocast_dtype)`` pair produced by ``build_model_for_eval``
        and ``get_autocast_dtype`` respectively.
    """
    # Turn on cuDNN benchmark mode before anything is built.
    cudnn.benchmark = True

    cfg = setup(args)

    # Tuple elements are evaluated left to right: the model is built first,
    # then the dtype is looked up from the same config.
    return build_model_for_eval(cfg, args.pretrained_weights), get_autocast_dtype(cfg)