# Deploy_Restoration / SuperResolution.py
# Uploaded by AlexZou in revision 7970501 ("Upload 4 files").
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import argparse
from models.SCET import SCET
def inference_img(img_path, Net):
    """Run the restoration network on a single image file.

    Args:
        img_path: Path to the input image. It is decoded with PIL and
            converted to 3-channel RGB before inference.
        Net: A callable model (in this script, an ``SCET`` module already
            switched to eval mode) that accepts a batched image tensor.

    Returns:
        A tuple ``(restored, elapsed)``. ``restored`` is whatever ``Net``
        returns for the batched input (presumably an upscaled
        ``(1, 3, H', W')`` tensor -- TODO confirm against SCET), and
        ``elapsed`` is the duration of the forward pass in seconds.
    """
    # The context manager releases the file handle deterministically;
    # convert() returns a new, fully loaded image that no longer needs it.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')

    # ToTensor produces a float CHW tensor scaled to [0, 1]; unsqueeze(0)
    # prepends the batch dimension before calling the network.
    to_tensor = transforms.ToTensor()
    with torch.no_grad():
        batch = to_tensor(low_image).unsqueeze(0)
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which makes it the right clock for measuring an interval.
        start = time.perf_counter()
        restored = Net(batch)
        elapsed = time.perf_counter() - start
    return restored, elapsed
if __name__ == '__main__':
    # Command-line entry point: restore one image with a pretrained SCET
    # checkpoint and write the result to <save_path>/output.png.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/SRx4.pth',
                        help='Path of the checkpoint')
    parser.add_argument('--scale', type=int, default=4, help='scale factor')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories, and exist_ok avoids
    # the check-then-create race of a separate isdir()/mkdir() pair.
    os.makedirs(opt.save_path, exist_ok=True)

    # NOTE(review): the first SCET argument is 63 for x3 and 64 otherwise;
    # presumably it must be divisible by the scale factor -- confirm against
    # models/SCET.py. The checkpoint must match the chosen scale.
    if opt.scale == 3:
        Net = SCET(63, 128, opt.scale)
    else:
        Net = SCET(64, 128, opt.scale)
    # torch.load unpickles the file, so only load checkpoints from trusted
    # sources. Weights are mapped to CPU so no GPU is required.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    print(f'Inference time: {time_num:.4f}s')
    # os.path.join inserts the path separator; plain concatenation wrote
    # "<save_path>output.png" beside save_path (not inside it) whenever the
    # argument lacked a trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))