import os
import argparse
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFileDataset

# Per-dataset configuration: network input resolution and the label set the
# checkpoint predicts (index in 'label' == class id in the parsing output).
dataset_settings = {
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    }
}


def get_arguments():
    """Parse command-line options for single-image human parsing."""
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
    parser.add_argument("--dataset", type=str, default='atr', choices=['atr'])
    parser.add_argument("--model-restore", type=str, default='', help="Path to pretrained model.")
    parser.add_argument("--gpu", type=str, default='0', help="GPU device.")
    parser.add_argument("--input-path", type=str, default='', help="Path to a single input image.")
    parser.add_argument("--output-dir", type=str, default='', help="Path of output image folder.")
    return parser.parse_args()


def main():
    """Run the parsing network on one image and save a binary upper-body mask.

    The saved PNG marks Upper-clothes, Left-arm and Right-arm pixels
    (ATR class ids 4, 14, 15) as 255 and all other pixels as 0, resized
    to a fixed 768x1024 output resolution.
    """
    args = get_arguments()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Checkpoints trained with nn.DataParallel prefix every key with
        # 'module.'. Fix: strip the prefix only when it is actually present,
        # so checkpoints saved without DataParallel also load (the original
        # k[7:] silently corrupted those keys).
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    # NOTE(review): mean/std appear channel-reversed relative to the usual
    # RGB ImageNet stats — presumably the dataset yields BGR tensors, as in
    # the original SCHP preprocessing; confirm against SimpleFileDataset.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])

    # SimpleFileDataset wraps a single image file (vs. SimpleFolderDataset).
    dataset = SimpleFileDataset(img_path=args.input_path, input_size=input_size, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1)  # one image -> batch of 1

    os.makedirs(args.output_dir, exist_ok=True)

    # Loop-invariant hoisted: build the upsampler once, not per batch.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]

            output = model(image.cuda())
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze().permute(1, 2, 0)  # CHW -> HWC

            parsing_result = np.argmax(upsample_output.cpu().numpy(), axis=2)

            # Binary mask (255/0) for Upper-clothes (4), Left-arm (14),
            # Right-arm (15).
            mask = np.isin(parsing_result, [4, 14, 15]) * 255
            mask_img = Image.fromarray(mask.astype(np.uint8))

            # Target output resolution (width, height).
            new_size = (768, 1024)
            # Fix: resample with NEAREST so the mask stays strictly binary;
            # the default filter would interpolate gray values along edges.
            mask_img = mask_img.resize(new_size, Image.NEAREST)

            # Fix: splitext handles any extension length (.jpeg, .png, ...),
            # unlike the brittle img_name[:-4].
            base, _ = os.path.splitext(img_name)
            mask_img.save(os.path.join(args.output_dir, base + ".png"))


if __name__ == '__main__':
    main()