Spaces:
Build error
Build error
Upload 4 files
Browse files- SuperResolution.py +7 -8
- Underwater.py +6 -7
SuperResolution.py
CHANGED
@@ -8,7 +8,7 @@ import torchvision
|
|
8 |
import argparse
|
9 |
from models.SCET import SCET
|
10 |
|
11 |
-
def inference_img(img_path,Net
|
12 |
|
13 |
low_image = Image.open(img_path).convert('RGB')
|
14 |
enhance_transforms = transforms.Compose([
|
@@ -19,7 +19,7 @@ def inference_img(img_path,Net,device):
|
|
19 |
low_image = enhance_transforms(low_image)
|
20 |
low_image = low_image.unsqueeze(0)
|
21 |
start = time.time()
|
22 |
-
restored2 = Net(low_image
|
23 |
end = time.time()
|
24 |
|
25 |
|
@@ -31,17 +31,16 @@ if __name__ == '__main__':
|
|
31 |
parser.add_argument('--save_path',type=str,required=True,help='Path to save')
|
32 |
parser.add_argument('--pk_path',type=str,default='model_zoo/SRx4.pth',help='Path of the checkpoint')
|
33 |
parser.add_argument('--scale',type=int,default=4,help='scale factor')
|
34 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
35 |
opt = parser.parse_args()
|
36 |
if not os.path.isdir(opt.save_path):
|
37 |
os.mkdir(opt.save_path)
|
38 |
if opt.scale == 3:
|
39 |
-
Net = SCET(63, 128, opt.scale)
|
40 |
else:
|
41 |
-
Net = SCET(64, 128, opt.scale)
|
42 |
-
Net.load_state_dict(torch.load(opt.pk_path))
|
43 |
-
Net=Net.
|
44 |
image=opt.test_path
|
45 |
print(image)
|
46 |
-
restored2,time_num=inference_img(image,Net
|
47 |
torchvision.utils.save_image(restored2,opt.save_path+'output.png')
|
|
|
8 |
import argparse
|
9 |
from models.SCET import SCET
|
10 |
|
11 |
+
def inference_img(img_path,Net):
|
12 |
|
13 |
low_image = Image.open(img_path).convert('RGB')
|
14 |
enhance_transforms = transforms.Compose([
|
|
|
19 |
low_image = enhance_transforms(low_image)
|
20 |
low_image = low_image.unsqueeze(0)
|
21 |
start = time.time()
|
22 |
+
restored2 = Net(low_image)
|
23 |
end = time.time()
|
24 |
|
25 |
|
|
|
31 |
parser.add_argument('--save_path',type=str,required=True,help='Path to save')
|
32 |
parser.add_argument('--pk_path',type=str,default='model_zoo/SRx4.pth',help='Path of the checkpoint')
|
33 |
parser.add_argument('--scale',type=int,default=4,help='scale factor')
|
|
|
34 |
opt = parser.parse_args()
|
35 |
if not os.path.isdir(opt.save_path):
|
36 |
os.mkdir(opt.save_path)
|
37 |
if opt.scale == 3:
|
38 |
+
Net = SCET(63, 128, opt.scale)
|
39 |
else:
|
40 |
+
Net = SCET(64, 128, opt.scale)
|
41 |
+
Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
|
42 |
+
Net=Net.eval()
|
43 |
image=opt.test_path
|
44 |
print(image)
|
45 |
+
restored2,time_num=inference_img(image,Net)
|
46 |
torchvision.utils.save_image(restored2,opt.save_path+'output.png')
|
Underwater.py
CHANGED
@@ -11,7 +11,7 @@ import torch.functional as F
|
|
11 |
import argparse
|
12 |
from net.Ushape_Trans import *
|
13 |
|
14 |
-
def inference_img(img_path,Net
|
15 |
|
16 |
low_image = Image.open(img_path).convert('RGB')
|
17 |
enhance_transforms = transforms.Compose([
|
@@ -23,7 +23,7 @@ def inference_img(img_path,Net,device):
|
|
23 |
low_image = enhance_transforms(low_image)
|
24 |
low_image = low_image.unsqueeze(0)
|
25 |
start = time.time()
|
26 |
-
restored2 = Net(low_image
|
27 |
end = time.time()
|
28 |
|
29 |
|
@@ -37,11 +37,10 @@ if __name__ == '__main__':
|
|
37 |
opt = parser.parse_args()
|
38 |
if not os.path.isdir(opt.save_path):
|
39 |
os.mkdir(opt.save_path)
|
40 |
-
|
41 |
-
Net =
|
42 |
-
Net.
|
43 |
-
Net = Net.to(device)
|
44 |
image = opt.test_path
|
45 |
print(image)
|
46 |
-
restored2,time_num = inference_img(image,Net
|
47 |
torchvision.utils.save_image(restored2,opt.save_path+'output.png')
|
|
|
11 |
import argparse
|
12 |
from net.Ushape_Trans import *
|
13 |
|
14 |
+
def inference_img(img_path,Net):
|
15 |
|
16 |
low_image = Image.open(img_path).convert('RGB')
|
17 |
enhance_transforms = transforms.Compose([
|
|
|
23 |
low_image = enhance_transforms(low_image)
|
24 |
low_image = low_image.unsqueeze(0)
|
25 |
start = time.time()
|
26 |
+
restored2 = Net(low_image)
|
27 |
end = time.time()
|
28 |
|
29 |
|
|
|
37 |
opt = parser.parse_args()
|
38 |
if not os.path.isdir(opt.save_path):
|
39 |
os.mkdir(opt.save_path)
|
40 |
+
Net = Generator()
|
41 |
+
Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
|
42 |
+
Net = Net.eval()
|
|
|
43 |
image = opt.test_path
|
44 |
print(image)
|
45 |
+
restored2,time_num = inference_img(image,Net)
|
46 |
torchvision.utils.save_image(restored2,opt.save_path+'output.png')
|