rizavelioglu committed on
Commit 7c23ab5 · 0 Parent(s)

Initial commit

.gitattributes ADDED
@@ -0,0 +1,42 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/01_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/02_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/03_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/05_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/06_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/08_1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/10_1.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: FR IQA
+ emoji: 🌖
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.42.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Compute similarity between two images using FR-IQA metrics
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
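For programmatic access, the Space can also be queried from Python with gradio_client. A minimal sketch, assuming the Space is published under the id rizavelioglu/fr-iqa (unverified) and that the endpoint name matches the handler defined later in app.py; client.view_api() prints the actual endpoint name and signature:

from gradio_client import Client, handle_file

client = Client("rizavelioglu/fr-iqa")   # assumed Space id
client.view_api()                        # lists the real endpoint name and inputs

result = client.predict(
    handle_file("img1.jpg"),             # placeholder local paths; both images must have the same size
    handle_file("img2.jpg"),
    api_name="/compute_similarity",      # assumed; verify with view_api()
)
print(result)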
app.py ADDED
@@ -0,0 +1,125 @@
+ import os
+ import gradio as gr
+ from DISTS_pytorch import DISTS
+ from torchvision.io import read_image
+ import torch
+ import torchvision.transforms.v2 as transforms
+ import pyiqa
+ import spaces
+ from metrics.DeepDC import DeepDC
+ from metrics.DeepWSD import DeepWSD
+ from metrics.ADISTS import ADISTS
+
+ @spaces.GPU(duration=5)
+ class Evaluator:
+     def __init__(self, device):
+         self.device = device
+         self.metrics = {
+             "↓ MSE": torch.nn.functional.mse_loss,
+             "↓ L1": torch.nn.functional.l1_loss,
+             "↓ DISTS": DISTS().to(self.device),
+             "↓ LPIPS": pyiqa.create_metric("lpips", device=self.device),
+             "↑ PSNR": pyiqa.create_metric("psnr", device=self.device),
+             "↑ SSIM": pyiqa.create_metric("ssim", device=self.device),
+             "↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device),
+             "↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device),
+             "↑ FSIM": pyiqa.create_metric("fsim", device=self.device),
+             "↑ DeepDC": DeepDC().to(self.device),
+             "↑ DeepWSD": DeepWSD().to(self.device),
+             "↑ ADISTS": ADISTS().to(self.device),
+         }
+         self.transform = transforms.ToDtype(dtype=torch.float32, scale=True)
+
+     @torch.no_grad()
+     def evaluate(self, img_fname1, img_fname2):
+         img1 = self.transform(read_image(img_fname1)).unsqueeze(0).to(self.device)
+         img2 = self.transform(read_image(img_fname2)).unsqueeze(0).to(self.device)
+
+         # Check that the images have the same size
+         if img1.shape != img2.shape:
+             raise gr.Error("Input images must have the same dimensions!")
+
+         return "\n".join(
+             f"{name:<10}: {float(metric(img1, img2).item()):3,.5f}"
+             for name, metric in self.metrics.items()
+         )
+
+
+ @spaces.GPU(duration=1)
+ def get_evaluator():
+     """Returns a singleton Evaluator instance per worker/session."""
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if not hasattr(get_evaluator, "evaluator"):
+         get_evaluator.evaluator = Evaluator(device)
+     return get_evaluator.evaluator
+
+
+ @spaces.GPU(duration=5)
+ def compute_similarity(img_fname1, img_fname2):
+     return get_evaluator().evaluate(img_fname1, img_fname2)
+
+
+ if __name__ == "__main__":
+     examples = [
+         ["examples/01_1.jpg", "examples/01_1.jpg"],  # Add an extra example for identical images
+         *[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)],
+     ]
+
+     # Gradio UI
+     custom_css = """
+     .center-header {
+         display: flex;
+         align-items: center;
+         justify-content: center;
+         margin: 0 0 10px 0;
+     }
+     .monospace-text {
+         font-family: 'Courier New', Courier, monospace;
+     }
+     """
+     with gr.Blocks(title="FR-IQA", css=custom_css) as demo:
+         gr.Markdown(f"""
+         <div class='center-header'><h1>FR-IQA</h1></div>
+         Upload two images to compute various Full-Reference IQA metrics for measuring the similarity between them.<br>
+         <b>Note</b>: Images must be of the same size.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 img_fname1 = gr.Image(type="filepath", label="Image#1", height=512, width=512)
+             with gr.Column(scale=2):
+                 img_fname2 = gr.Image(type="filepath", label="Image#2", height=512, width=512)
+             with gr.Column(scale=1):
+                 metrics_output = gr.Textbox(label="Metrics Output", lines=22, elem_classes="monospace-text", show_copy_button=True)
+
+         with gr.Row():
+             submit_btn = gr.Button("Compute!")
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 gr.Examples(
+                     examples=examples,
+                     inputs=[img_fname1, img_fname2],
+                     fn=compute_similarity,
+                     outputs=metrics_output,
+                     label="Example Image Pairs (all images are 1024×768)",
+                     cache_examples=False,
+                     examples_per_page=5
+                 )
+             with gr.Column(scale=2):
+                 gr.Markdown("""
+                 <div class='center-header'><h3>Acknowledgements</h3></div>
+
+                 - Example images are from the [TryOffDiff](https://rizavelioglu.github.io/tryoffdiff) paper, which are sampled from the VITON-HD dataset.
+                 - We use the [IQA-PyTorch](https://github.com/chaofengc/IQA-PyTorch) library for computing the metrics.
+
+                 """)
+
+         submit_btn.click(
+             fn=compute_similarity,
+             inputs=[img_fname1, img_fname2],
+             outputs=[metrics_output]
+         )
+
+     demo.launch(share=False, ssr_mode=False)
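For reference, the metric calls used in app.py can also be run outside the Gradio/Spaces wrapper. A minimal standalone sketch using the same IQA-PyTorch calls, assuming two equal-sized images on disk (img1.jpg and img2.jpg are placeholder paths):

import torch
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
import pyiqa

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
to_float = transforms.ToDtype(dtype=torch.float32, scale=True)  # uint8 [0, 255] -> float32 [0, 1]

# Load both images as 1x3xHxW tensors in [0, 1], mirroring Evaluator.evaluate()
img1 = to_float(read_image("img1.jpg")).unsqueeze(0).to(device)
img2 = to_float(read_image("img2.jpg")).unsqueeze(0).to(device)
assert img1.shape == img2.shape, "Input images must have the same dimensions"

# A subset of the full-reference metrics listed in Evaluator.metrics
for name in ("lpips", "psnr", "ssim"):
    metric = pyiqa.create_metric(name, device=device)
    print(f"{name:>5}: {metric(img1, img2).item():.5f}")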
examples/01_1.jpg ADDED

Git LFS Details

  • SHA256: 5f9250903aa2221e23b63ff3e7bb4dadcb2e907b4845802b1f755ad2e8ffd234
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
examples/01_2.jpg ADDED
examples/02_1.jpg ADDED

Git LFS Details

  • SHA256: 8dafab5573b689bc870b43e646417e8d2ac2fce43654dc31afdef518654e667c
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
examples/02_2.jpg ADDED
examples/03_1.jpg ADDED

Git LFS Details

  • SHA256: 5b3b5f77ec0617312718fc9be9207d009625fee7d0442178262a0a2e04e98243
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
examples/03_2.jpg ADDED
examples/04_1.jpg ADDED
examples/04_2.jpg ADDED
examples/05_1.jpg ADDED

Git LFS Details

  • SHA256: 4a34eaef0a4f12ff22850c6b309af64b9b32c321c2df78d1d9416028cd0b8efb
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
examples/05_2.jpg ADDED
examples/06_1.jpg ADDED
examples/06_2.jpg ADDED
examples/07_1.jpg ADDED

Git LFS Details

  • SHA256: 3933ca6edf8e3cc8784448f762c82de89e69cb126936c7a52d76de4e73678386
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
examples/07_2.jpg ADDED
examples/08_1.jpg ADDED
examples/08_2.jpg ADDED
examples/09_1.jpg ADDED

Git LFS Details

  • SHA256: 8990cc2a01d39dc6c325e63647c826f197e7ddd9d6b82e7ebadd2da4f5c61b9e
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
examples/09_2.jpg ADDED
metrics/ADISTS.py ADDED
@@ -0,0 +1,157 @@
+ """Taken from:
+ https://github.com/dingkeyan93/A-DISTS/blob/3d20592648625df2e451c9aba25bbaf3c7952ac8/A-DISTS.py
+ """
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torchvision
+ from torchvision import models, transforms
+ from torch.nn.functional import normalize
+ import torch.nn.functional as F
+ import math
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ class Downsample(nn.Module):
+     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
+         super(Downsample, self).__init__()
+         self.padding = (filter_size - 2) // 2
+         self.stride = stride
+         self.channels = channels
+         a = np.hanning(filter_size)[1:-1]
+         g = torch.Tensor(a[:, None] * a[None, :])
+         g = g / torch.sum(g)
+         self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+         # print(g)
+
+     def forward(self, input):
+         input = input**2
+         out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
+         return (out + 1e-12).sqrt()
+
+ class ADISTS(torch.nn.Module):
+     def __init__(self, window_size=21):
+         super(ADISTS, self).__init__()
+         vgg_pretrained_features = models.vgg16(pretrained=True).features
+
+         self.stage1 = torch.nn.Sequential()
+         self.stage2 = torch.nn.Sequential()
+         self.stage3 = torch.nn.Sequential()
+         self.stage4 = torch.nn.Sequential()
+         self.stage5 = torch.nn.Sequential()
+
+         for x in range(0, 4):
+             self.stage1.add_module(str(x), vgg_pretrained_features[x])
+         self.stage2.add_module(str(4), Downsample(channels=64))
+         for x in range(5, 9):
+             self.stage2.add_module(str(x), vgg_pretrained_features[x])
+         self.stage3.add_module(str(9), Downsample(channels=128))
+         for x in range(10, 16):
+             self.stage3.add_module(str(x), vgg_pretrained_features[x])
+         self.stage4.add_module(str(16), Downsample(channels=256))
+         for x in range(17, 23):
+             self.stage4.add_module(str(x), vgg_pretrained_features[x])
+         self.stage5.add_module(str(23), Downsample(channels=512))
+         for x in range(24, 30):
+             self.stage5.add_module(str(x), vgg_pretrained_features[x])
+
+         for param in self.parameters():
+             param.requires_grad = False
+
+         self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
+         self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))
+
+         self.chns = [3, 64, 128, 256, 512, 512]
+         self.windows = nn.ParameterList()
+         self.window_size = window_size
+         for k in range(len(self.chns)):
+             self.windows.append(self.create_window(self.window_size, self.window_size / 3, self.chns[k]))
+
+
+     def compute_prob(self, x, k):
+
+         theta = [[0, 0],
+                  [1.0, 0.29],
+                  [2.0, 0.52],
+                  [2.95, 0.56],
+                  [0.97, 0.25],
+                  [0.21, 0.10]]
+
+         ps = 1 / (1 + torch.exp(-(x - theta[k][0]) / theta[k][1]))
+         pt = 1 - ps
+         return ps, pt
+
+
+     def gaussian(self, window_size, sigma):
+         gauss = torch.Tensor([math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) for x in range(window_size)])
+         return gauss / gauss.sum()
+
+     def create_window(self, window_size, window_sigma, channel):
+         _1D_window = self.gaussian(window_size, window_sigma).unsqueeze(1)
+         _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+         window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+         return nn.Parameter(window, requires_grad=False)
+
+     def forward_once(self, x):
+         h = (x - self.mean) / self.std
+         h = self.stage1(h)
+         h_relu1_2 = h
+         h = self.stage2(h)
+         h_relu2_2 = h
+         h = self.stage3(h)
+         h_relu3_3 = h
+         h = self.stage4(h)
+         h_relu4_3 = h
+         if len(self.chns) == 6:
+             h = self.stage5(h)
+             h_relu5_3 = h
+             outs = [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
+         else:
+             outs = [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3]
+         return outs
+
+     def forward(self, x, y, as_loss=False):
+         assert x.shape == y.shape
+         if as_loss:
+             feats_x = self.forward_once(x)
+             feats_y = self.forward_once(y)
+         else:
+             with torch.no_grad():
+                 feats_x = self.forward_once(x)
+                 feats_y = self.forward_once(y)
+
+         D = 0
+         c1 = 1e-6
+         c2 = 1e-6
+         pad = nn.ReflectionPad2d(0)
+         for k in range(len(self.chns) - 1, -1, -1):
+             try:
+                 x_mean = F.conv2d(pad(feats_x[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k])
+                 y_mean = F.conv2d(pad(feats_y[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k])
+                 x_var = F.conv2d(pad(feats_x[k]**2), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - x_mean**2
+                 y_var = F.conv2d(pad(feats_y[k]**2), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - y_mean**2
+                 xy_cov = F.conv2d(pad(feats_x[k] * feats_y[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - x_mean * y_mean
+             except:
+                 x_mean = feats_x[k].mean([2, 3], keepdim=True)
+                 y_mean = feats_y[k].mean([2, 3], keepdim=True)
+                 x_var = ((feats_x[k] - x_mean)**2).mean([2, 3], keepdim=True)
+                 y_var = ((feats_y[k] - y_mean)**2).mean([2, 3], keepdim=True)
+                 xy_cov = (feats_x[k] * feats_y[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
+
+             T = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
+             S = (2 * xy_cov + c2) / (x_var + y_var + c2)
+
+             if k > 0:
+                 ratio = torch.mean(x_var / (x_mean + 1e-12), dim=1, keepdim=True)
+                 ps, pt = self.compute_prob(ratio, k)
+
+             D_map = pt * T + ps * S
+             # D = D + D_map.mean([2,3]).mean(1)/len(self.chns)
+             D = D + D_map.mean([2, 3]).sum(1) / sum(self.chns)
+
+         if as_loss:
+             return 1 - D.mean()
+         else:
+             return 1 - D
+
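A minimal usage sketch for the class above, assuming batched float tensors in [0, 1] with identical shapes (random tensors are used as stand-ins for real images):

import torch
from metrics.ADISTS import ADISTS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ADISTS().to(device)

x_ref = torch.rand(1, 3, 256, 256, device=device)   # stand-in reference image
x_dist = torch.rand(1, 3, 256, 256, device=device)  # stand-in distorted image

score = model(x_ref, x_dist)                # per-pair score, shape (N,), computed under no_grad
loss = model(x_ref, x_dist, as_loss=True)   # batch-averaged version of the same quantity, with gradients enabled
print(score.item(), loss.item())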
metrics/DeepDC.py ADDED
@@ -0,0 +1,134 @@
+ """Taken from: https://github.com/h4nwei/DeepDC
+ """
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ from collections import OrderedDict
+ import numpy as np
+ import torch.nn.functional as F
+ import math
+ import torchvision
+ from torchvision import models, transforms
+
+
+ names = {'vgg19': ['image', 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
+                    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+                    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+                    'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
+                    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
+                    'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
+                    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+                    'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'],}
+
+ class MultiVGGFeaturesExtractor(nn.Module):
+     def __init__(self, target_features=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'), use_input_norm=False, requires_grad=False):  # ALL FALSE is the best for COS_Similarity; Correlation: use_norm = True
+         super(MultiVGGFeaturesExtractor, self).__init__()
+         self.use_input_norm = use_input_norm
+         self.target_features = target_features
+
+
+         model = torchvision.models.vgg19(pretrained=True)
+         names_key = 'vgg19'
+
+         if self.use_input_norm:
+             mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+             std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+             self.register_buffer('mean', mean)
+             self.register_buffer('std', std)
+
+         self.target_indexes = [names[names_key].index(k) - 1 for k in self.target_features]
+         self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)])
+
+         if not requires_grad:
+             for k, v in self.features.named_parameters():
+                 v.requires_grad = False
+             self.features.eval()
+
+     def forward(self, x):
+         # assume input range is [0, 1]
+         if self.use_input_norm:
+             x = (x - self.mean) / self.std
+
+         y = OrderedDict()
+         if 'image' in self.target_features:
+             y.update({"image": x})
+         for key, layer in self.features._modules.items():
+             x = layer(x)
+             # x = self._normalize_tensor(x)
+             if int(key) in self.target_indexes:
+                 y.update({self.target_features[self.target_indexes.index(int(key))]: x})
+         return y
+
+     def _normalize_tensor(self, in_feat, eps=1e-10):
+         norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
+         return in_feat / (norm_factor + eps)
+
+
+ class DeepDC(nn.Module):
+     def __init__(self, features_to_compute=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4')):
+         super(DeepDC, self).__init__()
+         self.MSE = torch.nn.MSELoss()
+         self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute).eval()
+
+     def forward(self, x, y):
+         r"""Compute IQA using the DeepDC model.
+
+         Args:
+             - x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.
+             - y: A reference tensor with (N, C, H, W) shape. RGB channel order for colour images.
+
+         Returns:
+             Value of the DeepDC model.
+
+         """
+         targets, inputs = x, y
+         inputs_fea = self.features_extractor(inputs)
+
+         with torch.no_grad():
+             targets_fea = self.features_extractor(targets)
+
+         dc_scores = []
+
+         for _, key in enumerate(inputs_fea.keys()):
+             inputs_dcdm = self._DCDM(inputs_fea[key])
+             targets_dcdm = self._DCDM(targets_fea[key])
+             dc_scores.append(self.Distance_Correlation(inputs_dcdm, targets_dcdm))
+
+         dc_scores = torch.stack(dc_scores, dim=1)
+
+         score = 1 - dc_scores.mean(dim=1, keepdim=True)
+
+         return score
+
+
+     # double-centered distance matrix (dcdm)
+     def _DCDM(self, x):
+         if len(x.shape) == 4:
+             batchSize, dim, h, w = x.data.shape
+             M = h * w
+         elif len(x.shape) == 3:
+             batchSize, M, dim = x.data.shape
+         x = x.reshape(batchSize, dim, M)
+         t = torch.log((1. / (torch.tensor(dim) * torch.tensor(dim))))
+
+         I = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(x.dtype)
+         I_M = torch.ones(batchSize, dim, dim, device=x.device).type(x.dtype)
+         x_pow2 = x.bmm(x.transpose(1, 2))
+         dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2
+
+         dcov = torch.clamp(dcov, min=0.0)
+         dcov = torch.exp(t) * dcov
+         dcov = torch.sqrt(dcov + 1e-5)
+         dcdm = dcov - 1. / dim * dcov.bmm(I_M) - 1. / dim * I_M.bmm(dcov) + 1. / (dim * dim) * I_M.bmm(dcov).bmm(I_M)
+
+         return dcdm
+
+
+     def Distance_Correlation(self, matrix_A, matrix_B):
+
+         Gamma_XY = torch.sum(matrix_A * matrix_B, dim=[1, 2])
+         Gamma_XX = torch.sum(matrix_A * matrix_A, dim=[1, 2])
+         Gamma_YY = torch.sum(matrix_B * matrix_B, dim=[1, 2])
+         c = 1e-6
+         correlation_r = (Gamma_XY + c) / (torch.sqrt(Gamma_XX * Gamma_YY) + c)
+         return correlation_r
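A minimal usage sketch for the class above, assuming batched RGB tensors in [0, 1] (random tensors are used as stand-ins for real images):

import torch
from metrics.DeepDC import DeepDC

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepDC().to(device).eval()

x = torch.rand(1, 3, 256, 256, device=device)  # input image (stand-in)
y = torch.rand(1, 3, 256, 256, device=device)  # reference image (stand-in)

with torch.no_grad():
    score = model(x, y)  # shape (N, 1): 1 minus the mean distance correlation over the five VGG-19 feature layers
print(score.item())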
metrics/DeepWSD.py ADDED
@@ -0,0 +1,169 @@
+ """Taken from: https://github.com/Buka-Xing/DeepWSD/blob/main/utils.py
+ """
+ import numpy as np
+ import os
+ import sys
+ import torch
+ from torchvision import models, transforms
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import inspect
+ from ot.lp import wasserstein_1d
+
+
+ # Process the input of VGG16 so that its size is close to 256
+ def downsample(img1, img2, maxSize=256):
+     _, channels, H, W = img1.shape
+     f = int(max(1, np.round(max(H, W) / maxSize)))
+
+     aveKernel = (torch.ones(channels, 1, f, f) / f**2).to(img1.device)
+     img1 = F.conv2d(img1, aveKernel, stride=f, padding=0, groups=channels)
+     img2 = F.conv2d(img2, aveKernel, stride=f, padding=0, groups=channels)
+     # For an extremely large image, a larger window is used to increase the receptive field.
+     if f >= 5:
+         win = 16
+     else:
+         win = 4
+     return img1, img2, win, f
+
+
+ # Use L2pooling for the VGG16 network.
+ # The original max-pooling generates distortions in the color channels during optimization.
+ class L2pooling(nn.Module):
+     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
+         super(L2pooling, self).__init__()
+         self.padding = (filter_size - 2) // 2
+         self.stride = stride
+         self.channels = channels
+         a = np.hanning(filter_size)[1:-1]
+
+         g = torch.Tensor(a[:, None] * a[None, :])
+         g = g / torch.sum(g)
+         self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+
+     def forward(self, input):
+         input = input**2
+         out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
+         return (out + 1e-12).sqrt()
+
+
+ def ws_distance(X, Y, P=2, win=4):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     chn_num = X.shape[1]
+     X_sum = X.sum().sum()
+     Y_sum = Y.sum().sum()
+
+     X_patch = torch.reshape(X, [win, win, chn_num, -1])
+     Y_patch = torch.reshape(Y, [win, win, chn_num, -1])
+     patch_num = (X.shape[2] // win) * (X.shape[3] // win)
+
+     X_1D = torch.reshape(X_patch, [-1, chn_num * patch_num])
+     Y_1D = torch.reshape(Y_patch, [-1, chn_num * patch_num])
+
+     X_1D_pdf = X_1D / (X_sum + 1e-6)
+     Y_1D_pdf = Y_1D / (Y_sum + 1e-6)
+
+     interval = np.arange(0, X_1D.shape[0], 1)
+     all_samples = torch.from_numpy(interval).to(device).repeat([patch_num * chn_num, 1]).t()
+
+     X_pdf = X_1D * X_1D_pdf
+     Y_pdf = Y_1D * Y_1D_pdf
+
+     wsd = wasserstein_1d(all_samples, all_samples, X_pdf, Y_pdf, P)
+
+     L2 = ((X_1D - Y_1D) ** 2).sum(dim=0)
+     w = (1 / (torch.sqrt(torch.exp((-1 / (wsd + 10)))) * (wsd + 10)**2))
+
+     final = wsd + L2 * w
+     # final = wsd
+
+     return final.sum()
+
+
+ class DeepWSD(torch.nn.Module):
+
+     def __init__(self, channels=3, load_weights=True):
+         assert channels == 3
+         super(DeepWSD, self).__init__()
+         self.window = 4
+
+         vgg_pretrained_features = models.vgg16(pretrained=True).features
+         self.stage1 = torch.nn.Sequential()
+         self.stage2 = torch.nn.Sequential()
+         self.stage3 = torch.nn.Sequential()
+         self.stage4 = torch.nn.Sequential()
+         self.stage5 = torch.nn.Sequential()
+
+         # Replace the output layer of every block in the VGG network: maxpool -> L2pooling
+         for x in range(0, 4):
+             self.stage1.add_module(str(x), vgg_pretrained_features[x])
+         self.stage2.add_module(str(4), L2pooling(channels=64))
+         for x in range(5, 9):
+             self.stage2.add_module(str(x), vgg_pretrained_features[x])
+         self.stage3.add_module(str(9), L2pooling(channels=128))
+         for x in range(10, 16):
+             self.stage3.add_module(str(x), vgg_pretrained_features[x])
+         self.stage4.add_module(str(16), L2pooling(channels=256))
+         for x in range(17, 23):
+             self.stage4.add_module(str(x), vgg_pretrained_features[x])
+         self.stage5.add_module(str(23), L2pooling(channels=512))
+         for x in range(24, 30):
+             self.stage5.add_module(str(x), vgg_pretrained_features[x])
+
+         for param in self.parameters():
+             param.requires_grad = False
+
+         self.chns = [3, 64, 128, 256, 512, 512]
+
+     def forward_once(self, x):
+         h = x
+         h = self.stage1(h)
+         h_relu1_2 = h
+         h = self.stage2(h)
+         h_relu2_2 = h
+         h = self.stage3(h)
+         h_relu3_3 = h
+         h = self.stage4(h)
+         h_relu4_3 = h
+         h = self.stage5(h)
+         h_relu5_3 = h
+         return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
+
+     def forward(self, x, y, as_loss=False, resize=True):
+         assert x.shape == y.shape
+         if resize:
+             x, y, window, f = downsample(x, y)
+         if as_loss:
+             feats0 = self.forward_once(x)
+             feats1 = self.forward_once(y)
+         else:
+             with torch.no_grad():
+                 feats0 = self.forward_once(x)
+                 feats1 = self.forward_once(y)
+         score = 0
+         layer_score = []
+         # To see the score of each layer, inspect layer_score in a debugger (e.g. PyCharm's debugging mode).
+
+         for k in range(len(self.chns)):
+             row_padding = round(feats0[k].size(2) / window) * window - feats0[k].size(2)
+             column_padding = round(feats0[k].size(3) / window) * window - feats0[k].size(3)
+
+             pad = nn.ZeroPad2d((column_padding, 0, 0, row_padding))
+             feats0_k = pad(feats0[k])
+             feats1_k = pad(feats1[k])
+
+             tmp = ws_distance(feats0_k, feats1_k, win=window)
+             layer_score.append(torch.log(tmp + 1))
+             score = score + tmp
+         score = score / (k + 1)
+
+         # For optimization, the logarithm is not used.
+         if as_loss:
+             return score
+         # We find that the log**2 output leads to higher PLCC results, thus we provide two output strategies.
+         # They only affect the PLCC of the quality-assessment results.
+         elif f == 1:
+             return torch.log(score + 1)
+         else:
+             return torch.log(score + 1)**2
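A minimal usage sketch for the class above, assuming batched RGB tensors with identical shapes (random tensors as stand-ins; with resize=True the inputs are first average-pooled towards the 256-pixel range by downsample()):

import torch
from metrics.DeepWSD import DeepWSD

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepWSD().to(device)

x = torch.rand(1, 3, 512, 512, device=device)  # stand-in image pair
y = torch.rand(1, 3, 512, 512, device=device)

score = model(x, y)               # log-scaled score averaged over the six feature stages
loss = model(x, y, as_loss=True)  # raw averaged Wasserstein-based score, without the log, for optimization
print(score.item(), loss.item())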
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ DISTS-pytorch
+ torch
+ torchvision
+ pyiqa
+ POT