# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, DINO and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from timm.data import create_transform
from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import str_to_interp_mode
from torchvision import transforms

from dataset_folder import RvlcdipImageFolder
def build_dataset(is_train, args):
    """Build the dataset (and its transform) for training or evaluation.

    Args:
        is_train: True selects the training split/root and training
            augmentation; False selects the evaluation split/root.
        args: namespace providing ``data_set``, ``data_path``,
            ``eval_data_path``, ``nb_classes``, plus every field consumed
            by ``build_transform``.

    Returns:
        Tuple ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: if ``args.data_set`` is not a supported dataset.
    """
    transform = build_transform(is_train, args)

    # Log the composed transform(s) so the run configuration is visible.
    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'rvlcdip':
        # Training reads from data_path; evaluation from eval_data_path.
        root = args.data_path if is_train else args.eval_data_path
        split = "train" if is_train else "test"
        dataset = RvlcdipImageFolder(root, split=split, transform=transform)
        nb_classes = args.nb_classes
        # Sanity check: the folder layout must expose exactly nb_classes classes.
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError()
    # NOTE: the original also asserted `nb_classes == args.nb_classes` here,
    # which is trivially true on the only reachable path (nb_classes is
    # assigned from args.nb_classes above) — removed as dead code.
    print("Number of the class = %d" % args.nb_classes)

    return dataset, nb_classes
def build_transform(is_train, args):
    """Create the image transform pipeline for training or evaluation.

    Training uses timm's ``create_transform`` (augmentation, random erasing);
    evaluation uses a resize/center-crop/normalize torchvision pipeline.
    Returns a callable transform (a ``transforms.Compose`` for evaluation).
    """
    resize_im = args.input_size > 32

    # Pick normalization statistics: ImageNet defaults vs. Inception stats.
    use_default_stats = args.imagenet_default_mean_and_std
    mean = IMAGENET_DEFAULT_MEAN if use_default_stats else IMAGENET_INCEPTION_MEAN
    std = IMAGENET_DEFAULT_STD if use_default_stats else IMAGENET_INCEPTION_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        train_tf = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            # For small inputs, swap the leading RandomResizedCropAndInterpolation
            # for a padded RandomCrop.
            train_tf.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return train_tf

    # Evaluation pipeline.
    steps = []
    if resize_im:
        if args.crop_pct is None:
            # NOTE(review): this writes back into args (original behavior kept);
            # callers presumably read args.crop_pct afterwards — confirm before
            # making it a local.
            args.crop_pct = 224 / 256 if args.input_size < 384 else 1.0
        # Resize so the subsequent center crop keeps the same ratio w.r.t.
        # 224-sized images.
        scaled_size = int(args.input_size / args.crop_pct)
        steps.append(
            transforms.Resize(scaled_size,
                              interpolation=str_to_interp_mode("bicubic")))
        steps.append(transforms.CenterCrop(args.input_size))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)