Spaces: Build error
Commit 8728fb1 · Parent(s): d44c9b5 · Update app.py
app.py CHANGED
@@ -31,20 +31,20 @@ def lab2rgb(L, AB):
     rgb = color.lab2rgb(Lab) * 255
     return rgb
 
-def get_transform(params=None, grayscale=False, method=Image.BICUBIC):
+def get_transform(model_name,params=None, grayscale=False, method=Image.BICUBIC):
     #params
-    preprocess = '
+    preprocess = 'resize'
     load_size = 256
     crop_size = 256
     transform_list = []
     if grayscale:
         transform_list.append(transforms.Grayscale(1))
-    if
+    if model_name == "Pix2Pix Unet 256":
         osize = [load_size, load_size]
         transform_list.append(transforms.Resize(osize, method))
-    if 'crop' in preprocess:
-
-
+    # if 'crop' in preprocess:
+    #     if params is None:
+    #         transform_list.append(transforms.RandomCrop(crop_size))
 
     return transforms.Compose(transform_list)

@@ -67,7 +67,7 @@ def inferRestoration(img, model_name):
     return result
 
 def inferColorization(img,model_name):
-    print(model_name)
+    #print(model_name)
     if model_name == "Pix2Pix Resnet 9block":
         model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
     elif model_name == "Pix2Pix Unet 256":

@@ -96,10 +96,12 @@ def inferColorization(img,model_name):
     image_pil = transforms.ToPILImage()(result)
     return image_pil
 
-    transform_seq = get_transform()
-
-
-
+    transform_seq = get_transform(model_name)
+    img = transform_seq(img)
+    # if model_name == "Pix2Pix Unet 256":
+    #     img.resize((256,256))
+    img = np.array(img)
+    lab = color.rgb2lab(img).astype(np.float32)
     lab_t = transforms.ToTensor()(lab)
     A = lab_t[[0], ...] / 50.0 - 1.0
     B = lab_t[[1, 2], ...] / 110.0

@@ -160,4 +162,4 @@ examples = [['example/1.jpeg',"BOPBTL","Deoldify"],['example/2.jpg',"BOPBTL","De
 iface = gr.Interface(run,
     [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
     outputs="image",
-    examples=examples).launch(debug=True,share=
+    examples=examples).launch(debug=True,share=True)
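
For context, the first hunk leaves get_transform as roughly the following. This is a sketch reconstructed from the hunk above, not a verbatim copy of app.py: it assumes the PIL and torchvision imports already present in the file, and it simply omits the crop branch that the commit comments out.

from PIL import Image
from torchvision import transforms

def get_transform(model_name, params=None, grayscale=False, method=Image.BICUBIC):
    # After this commit preprocessing is resize-only; the crop branch stays commented out.
    load_size = 256
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    # Only the "Pix2Pix Unet 256" model needs a fixed 256x256 input.
    if model_name == "Pix2Pix Unet 256":
        transform_list.append(transforms.Resize([load_size, load_size], method))
    return transforms.Compose(transform_list)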
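The third hunk routes the input through that transform before the Lab conversion inside inferColorization. A minimal sketch of the resulting preprocessing, assuming the numpy, scikit-image, and torchvision imports used in app.py, the get_transform helper shown above, and one of the bundled example images:

import numpy as np
from PIL import Image
from skimage import color
from torchvision import transforms

img = Image.open("example/1.jpeg").convert("RGB")      # bundled example image
img = get_transform("Pix2Pix Unet 256")(img)           # resize to 256x256 for the Unet model
lab = color.rgb2lab(np.array(img)).astype(np.float32)  # RGB -> CIELAB
lab_t = transforms.ToTensor()(lab)                      # (3, H, W) tensor: L, a, b planes
A = lab_t[[0], ...] / 50.0 - 1.0                        # L channel: [0, 100] -> about [-1, 1]
B = lab_t[[1, 2], ...] / 110.0                          # a/b channels: roughly [-110, 110] -> about [-1, 1]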
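The last hunk only completes the launch call: share=True asks Gradio to create a temporary public share link in addition to the local endpoint, and debug=True keeps errors visible in the Space logs. The gr.inputs.Image and gr.inputs.Radio components used here belong to the older Gradio API (gr.inputs was deprecated in later releases), so the Space presumably pins an older gradio version.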