Spaces: Running on Zero
rizavelioglu committed
Commit · 7c23ab5
Parent(s):
Initial commit
- .gitattributes +42 -0
- README.md +14 -0
- app.py +125 -0
- examples/01_1.jpg +3 -0
- examples/01_2.jpg +0 -0
- examples/02_1.jpg +3 -0
- examples/02_2.jpg +0 -0
- examples/03_1.jpg +3 -0
- examples/03_2.jpg +0 -0
- examples/04_1.jpg +0 -0
- examples/04_2.jpg +0 -0
- examples/05_1.jpg +3 -0
- examples/05_2.jpg +0 -0
- examples/06_1.jpg +0 -0
- examples/06_2.jpg +0 -0
- examples/07_1.jpg +3 -0
- examples/07_2.jpg +0 -0
- examples/08_1.jpg +0 -0
- examples/08_2.jpg +0 -0
- examples/09_1.jpg +3 -0
- examples/09_2.jpg +0 -0
- metrics/ADISTS.py +157 -0
- metrics/DeepDC.py +134 -0
- metrics/DeepWSD.py +169 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,42 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/01_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/02_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/03_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/05_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/06_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/08_1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/10_1.jpg filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: FR IQA
+emoji: 🌖
+colorFrom: gray
+colorTo: indigo
+sdk: gradio
+sdk_version: 5.42.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+short_description: Compute similarity between two images using FR-IQA metrics
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,125 @@
+import os
+import gradio as gr
+from DISTS_pytorch import DISTS
+from torchvision.io import read_image
+import torch
+import torchvision.transforms.v2 as transforms
+import pyiqa
+import spaces
+from metrics.DeepDC import DeepDC
+from metrics.DeepWSD import DeepWSD
+from metrics.ADISTS import ADISTS
+
+@spaces.GPU(duration=5)
+class Evaluator:
+    def __init__(self, device):
+        self.device = device
+        self.metrics = {
+            "↓ MSE": torch.nn.functional.mse_loss,
+            "↓ L1": torch.nn.functional.l1_loss,
+            "↓ DISTS": DISTS().to(self.device),
+            "↓ LPIPS": pyiqa.create_metric("lpips", device=self.device),
+            "↑ PSNR": pyiqa.create_metric("psnr", device=self.device),
+            "↑ SSIM": pyiqa.create_metric("ssim", device=self.device),
+            "↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device),
+            "↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device),
+            "↑ FSIM": pyiqa.create_metric("fsim", device=self.device),
+            "↑ DeepDC": DeepDC().to(self.device),
+            "↑ DeepWSD": DeepWSD().to(self.device),
+            "↑ ADISTS": ADISTS().to(self.device),
+        }
+        self.transform = transforms.ToDtype(dtype=torch.float32, scale=True)
+
+    @torch.no_grad()
+    def evaluate(self, img_fname1, img_fname2):
+        img1 = self.transform(read_image(img_fname1)).unsqueeze(0).to(self.device)
+        img2 = self.transform(read_image(img_fname2)).unsqueeze(0).to(self.device)
+
+        # check images are the same size
+        if img1.shape != img2.shape:
+            raise gr.Error("Input images must have the same dimensions!")
+
+        return "\n".join(
+            f"{name:<10}: {float(metric(img1, img2).item()):3,.5f}"
+            for name, metric in self.metrics.items()
+        )
+
+
+@spaces.GPU(duration=1)
+def get_evaluator():
+    """Returns a singleton Evaluator instance per worker/session."""
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    if not hasattr(get_evaluator, "evaluator"):
+        get_evaluator.evaluator = Evaluator(device)
+    return get_evaluator.evaluator
+
+
+@spaces.GPU(duration=5)
+def compute_similarity(img_fname1, img_fname2):
+    return get_evaluator().evaluate(img_fname1, img_fname2)
+
+
+if __name__ == "__main__":
+    examples = [
+        ["examples/01_1.jpg", "examples/01_1.jpg"],  # Add an extra example for identical images
+        *[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)],
+    ]
+
+    # Gradio UI
+    custom_css = """
+    .center-header {
+        display: flex;
+        align-items: center;
+        justify-content: center;
+        margin: 0 0 10px 0;
+    }
+    .monospace-text {
+        font-family: 'Courier New', Courier, monospace;
+    }
+    """
+    with gr.Blocks(title="FR-IQA", css=custom_css) as demo:
+        gr.Markdown(f"""
+        <div class='center-header'><h1>FR-IQA</h1></div>
+        Upload two images to compute various Full-Reference IQA metrics for measuring the similarity between them.<br>
+        <b>Note</b>: Images must be of the same size.
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                img_fname1 = gr.Image(type="filepath", label="Image#1", height=512, width=512)
+            with gr.Column(scale=2):
+                img_fname2 = gr.Image(type="filepath", label="Image#2", height=512, width=512)
+            with gr.Column(scale=1):
+                metrics_output = gr.Textbox(label="Metrics Output", lines=22, elem_classes="monospace-text", show_copy_button=True)
+
+        with gr.Row():
+            submit_btn = gr.Button("Compute!")
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                gr.Examples(
+                    examples=examples,
+                    inputs=[img_fname1, img_fname2],
+                    fn=compute_similarity,
+                    outputs=metrics_output,
+                    label="Example Image Pairs (all images are 1024×768)",
+                    cache_examples=False,
+                    examples_per_page=5
+                )
+            with gr.Column(scale=2):
+                gr.Markdown("""
+                <div class='center-header'><h3>Acknowledgements</h3></div>
+
+                - Example images are from the [TryOffDiff](https://rizavelioglu.github.io/tryoffdiff) paper; they are sampled from the VITON-HD dataset.
+                - We use the [IQA-PyTorch](https://github.com/chaofengc/IQA-PyTorch) library for computing the metrics.
+
+
+                """)
+
+        submit_btn.click(
+            fn=compute_similarity,
+            inputs=[img_fname1, img_fname2],
+            outputs=[metrics_output]
+        )
+
+    demo.launch(share=False, ssr_mode=False)
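The Space is driven entirely by `compute_similarity`, so it can also be exercised without the Gradio UI. Below is a minimal local smoke test, a sketch only, assuming the repository root is the working directory, the packages from requirements.txt plus `spaces` are installed, and `spaces.GPU` degrades to a pass-through outside a ZeroGPU Space:

# Hypothetical local check; not part of the commit.
from app import compute_similarity

# Identical images: the error-type metrics (MSE, L1, DISTS, LPIPS) should be ~0.
print(compute_similarity("examples/01_1.jpg", "examples/01_1.jpg"))

# One of the bundled example pairs (all example images are 1024x768).
print(compute_similarity("examples/01_1.jpg", "examples/01_2.jpg"))

The first call constructs the Evaluator and downloads the pretrained backbones used by the metrics, so it can take a while.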
examples/01_1.jpg
ADDED (image file, stored with Git LFS)
examples/01_2.jpg
ADDED (image file)
examples/02_1.jpg
ADDED (image file, stored with Git LFS)
examples/02_2.jpg
ADDED (image file)
examples/03_1.jpg
ADDED (image file, stored with Git LFS)
examples/03_2.jpg
ADDED (image file)
examples/04_1.jpg
ADDED (image file)
examples/04_2.jpg
ADDED (image file)
examples/05_1.jpg
ADDED (image file, stored with Git LFS)
examples/05_2.jpg
ADDED (image file)
examples/06_1.jpg
ADDED (image file)
examples/06_2.jpg
ADDED (image file)
examples/07_1.jpg
ADDED (image file, stored with Git LFS)
examples/07_2.jpg
ADDED (image file)
examples/08_1.jpg
ADDED (image file)
examples/08_2.jpg
ADDED (image file)
examples/09_1.jpg
ADDED (image file, stored with Git LFS)
examples/09_2.jpg
ADDED (image file)
metrics/ADISTS.py
ADDED
@@ -0,0 +1,157 @@
+"""Taken from:
+https://github.com/dingkeyan93/A-DISTS/blob/3d20592648625df2e451c9aba25bbaf3c7952ac8/A-DISTS.py
+"""
+import torch
+import torch.nn as nn
+import numpy as np
+import torchvision
+from torchvision import models,transforms
+from torch.nn.functional import normalize
+import torch.nn.functional as F
+import math
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+class Downsample(nn.Module):
+    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
+        super(Downsample, self).__init__()
+        self.padding = (filter_size - 2 )//2
+        self.stride = stride
+        self.channels = channels
+        a = np.hanning(filter_size)[1:-1]
+        g = torch.Tensor(a[:,None]*a[None,:])
+        g = g/torch.sum(g)
+        self.register_buffer('filter', g[None,None,:,:].repeat((self.channels,1,1,1)))
+        # print (g)
+
+    def forward(self, input):
+        input = input**2
+        out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
+        return (out+1e-12).sqrt()
+
+class ADISTS(torch.nn.Module):
+    def __init__(self, window_size=21):
+        super(ADISTS, self).__init__()
+        vgg_pretrained_features = models.vgg16(pretrained=True).features
+
+        self.stage1 = torch.nn.Sequential()
+        self.stage2 = torch.nn.Sequential()
+        self.stage3 = torch.nn.Sequential()
+        self.stage4 = torch.nn.Sequential()
+        self.stage5 = torch.nn.Sequential()
+
+        for x in range(0,4):
+            self.stage1.add_module(str(x), vgg_pretrained_features[x])
+        self.stage2.add_module(str(4), Downsample(channels=64))
+        for x in range(5, 9):
+            self.stage2.add_module(str(x), vgg_pretrained_features[x])
+        self.stage3.add_module(str(9), Downsample(channels=128))
+        for x in range(10, 16):
+            self.stage3.add_module(str(x), vgg_pretrained_features[x])
+        self.stage4.add_module(str(16), Downsample(channels=256))
+        for x in range(17, 23):
+            self.stage4.add_module(str(x), vgg_pretrained_features[x])
+        self.stage5.add_module(str(23), Downsample(channels=512))
+        for x in range(24, 30):
+            self.stage5.add_module(str(x), vgg_pretrained_features[x])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1,-1,1,1))
+        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1,-1,1,1))
+
+        self.chns = [3,64,128,256,512,512]
+        self.windows = nn.ParameterList()
+        self.window_size = window_size
+        for k in range(len(self.chns)):
+            self.windows.append(self.create_window(self.window_size, self.window_size/3, self.chns[k]))
+
+
+    def compute_prob(self, x, k):
+
+        theta = [[0, 0],
+                 [1.0, 0.29],
+                 [2.0, 0.52],
+                 [2.95, 0.56],
+                 [0.97, 0.25],
+                 [0.21, 0.10]]
+
+        ps = 1/(1+torch.exp(-(x-theta[k][0])/theta[k][1]))
+        pt = 1 - ps
+        return ps, pt
+
+
+    def gaussian(self, window_size, sigma):
+        gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+        return gauss/gauss.sum()
+
+    def create_window(self, window_size, window_sigma, channel):
+        _1D_window = self.gaussian(window_size, window_sigma).unsqueeze(1)
+        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+        return nn.Parameter(window, requires_grad=False)
+
+    def forward_once(self, x):
+        h = (x-self.mean)/self.std
+        h = self.stage1(h)
+        h_relu1_2 = h
+        h = self.stage2(h)
+        h_relu2_2 = h
+        h = self.stage3(h)
+        h_relu3_3 = h
+        h = self.stage4(h)
+        h_relu4_3 = h
+        if len(self.chns)==6:
+            h = self.stage5(h)
+            h_relu5_3 = h
+            outs = [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
+        else:
+            outs = [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3]
+        return outs
+
+    def forward(self, x, y, as_loss=False):
+        assert x.shape == y.shape
+        if as_loss:
+            feats_x = self.forward_once(x)
+            feats_y = self.forward_once(y)
+        else:
+            with torch.no_grad():
+                feats_x = self.forward_once(x)
+                feats_y = self.forward_once(y)
+
+        D = 0
+        c1 = 1e-6
+        c2 = 1e-6
+        pad = nn.ReflectionPad2d(0)
+        for k in range(len(self.chns)-1, -1, -1):
+            try:
+                x_mean = F.conv2d(pad(feats_x[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k])
+                y_mean = F.conv2d(pad(feats_y[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k])
+                x_var = F.conv2d(pad(feats_x[k]**2), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - x_mean**2
+                y_var = F.conv2d(pad(feats_y[k]**2), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - y_mean**2
+                xy_cov = F.conv2d(pad(feats_x[k]*feats_y[k]), self.windows[k], stride=1, padding=0, groups=self.chns[k]) - x_mean*y_mean
+            except:
+                x_mean = feats_x[k].mean([2,3], keepdim=True)
+                y_mean = feats_y[k].mean([2,3], keepdim=True)
+                x_var = ((feats_x[k]-x_mean)**2).mean([2,3], keepdim=True)
+                y_var = ((feats_y[k]-y_mean)**2).mean([2,3], keepdim=True)
+                xy_cov = (feats_x[k]*feats_y[k]).mean([2,3], keepdim=True) - x_mean*y_mean
+
+            T = (2*x_mean*y_mean+c1)/(x_mean**2+y_mean**2+c1)
+            S = (2*xy_cov+c2)/(x_var+y_var+c2)
+
+            if k>0:
+                ratio = torch.mean(x_var/(x_mean+1e-12), dim=1, keepdim=True)
+                ps, pt = self.compute_prob(ratio, k)
+
+            D_map = pt*T+ps*S
+            # D = D + D_map.mean([2,3]).mean(1)/len(self.chns)
+            D = D + D_map.mean([2,3]).sum(1)/sum(self.chns)
+
+        if as_loss:
+            return 1-D.mean()
+        else:
+            return 1-D
+
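A minimal sketch of using the ADISTS module above on its own (assumptions: random tensors stand in for RGB images in [0, 1], and the torchvision VGG-16 weights are downloaded on first use):

# Hypothetical standalone check; not part of the commit.
import torch
from metrics.ADISTS import ADISTS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adists = ADISTS().to(device)

x = torch.rand(1, 3, 256, 256, device=device)  # stand-in for the distorted image
y = torch.rand(1, 3, 256, 256, device=device)  # stand-in for the reference image

with torch.no_grad():
    print(adists(x, y))            # one value per batch element; identical inputs give ~0
print(adists(x, y, as_loss=True))  # scalar variant (1 - D.mean()) intended for optimization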
metrics/DeepDC.py
ADDED
@@ -0,0 +1,134 @@
+"""Taken from: https://github.com/h4nwei/DeepDC
+"""
+import torch
+import torch.nn as nn
+from typing import Optional
+from collections import OrderedDict
+import numpy as np
+import torch.nn.functional as F
+import math
+import torchvision
+from torchvision import models, transforms
+
+
+names = {'vgg19': ['image', 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
+                   'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+                   'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+                   'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
+                   'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
+                   'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
+                   'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+                   'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'],}
+
+class MultiVGGFeaturesExtractor(nn.Module):
+    def __init__(self, target_features=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'), use_input_norm=False, requires_grad=False):  # ALL FALSE is the best for COS_Similarity; Correlation: use_norm = True
+        super(MultiVGGFeaturesExtractor, self).__init__()
+        self.use_input_norm = use_input_norm
+        self.target_features = target_features
+
+
+        model = torchvision.models.vgg19(pretrained=True)
+        names_key = 'vgg19'
+
+        if self.use_input_norm:
+            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+            self.register_buffer('mean', mean)
+            self.register_buffer('std', std)
+
+        self.target_indexes = [names[names_key].index(k) - 1 for k in self.target_features]
+        self.features = nn.Sequential(*list(model.features.children())[:(max(self.target_indexes) + 1)])
+
+        if not requires_grad:
+            for k, v in self.features.named_parameters():
+                v.requires_grad = False
+            self.features.eval()
+
+    def forward(self, x):
+        # assume input range is [0, 1]
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+
+        y = OrderedDict()
+        if 'image' in self.target_features:
+            y.update({"image": x})
+        for key, layer in self.features._modules.items():
+            x = layer(x)
+            # x = self._normalize_tensor(x)
+            if int(key) in self.target_indexes:
+                y.update({self.target_features[self.target_indexes.index(int(key))]: x})
+        return y
+
+    def _normalize_tensor(self, in_feat, eps=1e-10):
+        norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
+        return in_feat / (norm_factor + eps)
+
+
+class DeepDC(nn.Module):
+    def __init__(self, features_to_compute=('conv1_2', 'conv2_2', 'conv3_4', 'conv4_4', 'conv5_4')):
+        super(DeepDC, self).__init__()
+        self.MSE = torch.nn.MSELoss()
+        self.features_extractor = MultiVGGFeaturesExtractor(target_features=features_to_compute).eval()
+
+    def forward(self, x, y):
+        r"""Compute IQA using the DeepDC model.
+
+        Args:
+            - x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.
+            - y: A reference tensor with (N, C, H, W) shape. RGB channel order for colour images.
+
+        Returns:
+            Value of the DeepDC model.
+
+        """
+        targets, inputs = x, y
+        inputs_fea = self.features_extractor(inputs)
+
+        with torch.no_grad():
+            targets_fea = self.features_extractor(targets)
+
+        dc_scores = []
+
+        for _, key in enumerate(inputs_fea.keys()):
+            inputs_dcdm = self._DCDM(inputs_fea[key])
+            targets_dcdm = self._DCDM(targets_fea[key])
+            dc_scores.append(self.Distance_Correlation(inputs_dcdm, targets_dcdm))
+
+        dc_scores = torch.stack(dc_scores, dim=1)
+
+        score = 1 - dc_scores.mean(dim=1, keepdim=True)
+
+        return score
+
+
+    # double-centered distance matrix (dcdm)
+    def _DCDM(self, x):
+        if len(x.shape)==4:
+            batchSize, dim, h, w = x.data.shape
+            M = h * w
+        elif len(x.shape)==3:
+            batchSize, M, dim = x.data.shape
+        x = x.reshape(batchSize, dim, M)
+        t = torch.log((1. / (torch.tensor(dim) * torch.tensor(dim))))
+
+        I = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(x.dtype)
+        I_M = torch.ones(batchSize, dim, dim, device=x.device).type(x.dtype)
+        x_pow2 = x.bmm(x.transpose(1, 2))
+        dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2
+
+        dcov = torch.clamp(dcov, min=0.0)
+        dcov = torch.exp(t) * dcov
+        dcov = torch.sqrt(dcov + 1e-5)
+        dcdm = dcov - 1. / dim * dcov.bmm(I_M) - 1. / dim * I_M.bmm(dcov) + 1. / (dim * dim) * I_M.bmm(dcov).bmm(I_M)
+
+        return dcdm
+
+
+    def Distance_Correlation(self, matrix_A, matrix_B):
+
+        Gamma_XY = torch.sum(matrix_A * matrix_B, dim=[1,2])
+        Gamma_XX = torch.sum(matrix_A * matrix_A, dim=[1,2])
+        Gamma_YY = torch.sum(matrix_B * matrix_B, dim=[1,2])
+        c = 1e-6
+        correlation_r = (Gamma_XY + c) / (torch.sqrt(Gamma_XX * Gamma_YY) + c)
+        return correlation_r
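Similarly, a minimal sketch for the DeepDC module (same assumptions; the VGG-19 backbone is downloaded on first use). As the code shows, the forward pass returns 1 minus the mean distance correlation over the selected layers, so identical inputs yield values close to 0:

# Hypothetical standalone check; not part of the commit.
import torch
from metrics.DeepDC import DeepDC

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deepdc = DeepDC().to(device).eval()

ref = torch.rand(1, 3, 256, 256, device=device)
dist = (ref + 0.05 * torch.randn_like(ref)).clamp(0, 1)  # mildly perturbed copy

with torch.no_grad():
    print(deepdc(ref, dist))  # tensor of shape (N, 1)
    print(deepdc(ref, ref))   # ~0 for identical inputs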
metrics/DeepWSD.py
ADDED
@@ -0,0 +1,169 @@
+"""Taken from: https://github.com/Buka-Xing/DeepWSD/blob/main/utils.py
+"""
+import numpy as np
+import os
+import sys
+import torch
+from torchvision import models,transforms
+import torch.nn as nn
+import torch.nn.functional as F
+import inspect
+from ot.lp import wasserstein_1d
+
+
+# Process the input of VGG16 to bring its size close to 256
+def downsample(img1, img2, maxSize = 256):
+    _,channels,H,W = img1.shape
+    f = int(max(1,np.round(max(H,W)/maxSize)))
+
+    aveKernel = (torch.ones(channels,1,f,f)/f**2).to(img1.device)
+    img1 = F.conv2d(img1, aveKernel, stride=f, padding = 0, groups = channels)
+    img2 = F.conv2d(img2, aveKernel, stride=f, padding = 0, groups = channels)
+    # For an extremely large image, a larger window is used to increase the receptive field.
+    if f >= 5:
+        win = 16
+    else:
+        win = 4
+    return img1, img2, win, f
+
+
+# Use L2pooling for the VGG16 network.
+# The original max pooling generates distortions in the color channels during optimization.
+class L2pooling(nn.Module):
+    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
+        super(L2pooling, self).__init__()
+        self.padding = (filter_size - 2 )//2
+        self.stride = stride
+        self.channels = channels
+        a = np.hanning(filter_size)[1:-1]
+
+        g = torch.Tensor(a[:,None]*a[None,:])
+        g = g/torch.sum(g)
+        self.register_buffer('filter', g[None,None,:,:].repeat((self.channels,1,1,1)))
+
+    def forward(self, input):
+        input = input**2
+        out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
+        return (out+1e-12).sqrt()
+
+
+def ws_distance(X, Y, P=2, win=4):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    chn_num = X.shape[1]
+    X_sum = X.sum().sum()
+    Y_sum = Y.sum().sum()
+
+    X_patch = torch.reshape(X,[win,win,chn_num,-1])
+    Y_patch = torch.reshape(Y,[win,win,chn_num,-1])
+    patch_num = (X.shape[2]//win) * (X.shape[3]//win)
+
+    X_1D = torch.reshape(X_patch,[-1,chn_num*patch_num])
+    Y_1D = torch.reshape(Y_patch,[-1,chn_num*patch_num])
+
+    X_1D_pdf = X_1D / (X_sum + 1e-6)
+    Y_1D_pdf = Y_1D / (Y_sum + 1e-6)
+
+    interval = np.arange(0, X_1D.shape[0], 1)
+    all_samples = torch.from_numpy(interval).to(device).repeat([patch_num*chn_num,1]).t()
+
+    X_pdf = X_1D * X_1D_pdf
+    Y_pdf = Y_1D * Y_1D_pdf
+
+    wsd = wasserstein_1d(all_samples, all_samples, X_pdf, Y_pdf, P)
+
+    L2 = ((X_1D - Y_1D) ** 2).sum(dim=0)
+    w = (1 / ( torch.sqrt(torch.exp( (- 1/(wsd+10) ))) * (wsd+10)**2))
+
+    final = wsd + L2 * w
+    # final = wsd
+
+    return final.sum()
+
+
+class DeepWSD(torch.nn.Module):
+
+    def __init__(self, channels=3, load_weights=True):
+        assert channels == 3
+        super(DeepWSD, self).__init__()
+        self.window = 4
+
+        vgg_pretrained_features = models.vgg16(pretrained=True).features
+        self.stage1 = torch.nn.Sequential()
+        self.stage2 = torch.nn.Sequential()
+        self.stage3 = torch.nn.Sequential()
+        self.stage4 = torch.nn.Sequential()
+        self.stage5 = torch.nn.Sequential()
+
+        # Rewrite the output layer of every block in the VGG network: maxpool -> l2pool
+        for x in range(0, 4):
+            self.stage1.add_module(str(x), vgg_pretrained_features[x])
+        self.stage2.add_module(str(4), L2pooling(channels=64))
+        for x in range(5, 9):
+            self.stage2.add_module(str(x), vgg_pretrained_features[x])
+        self.stage3.add_module(str(9), L2pooling(channels=128))
+        for x in range(10, 16):
+            self.stage3.add_module(str(x), vgg_pretrained_features[x])
+        self.stage4.add_module(str(16), L2pooling(channels=256))
+        for x in range(17, 23):
+            self.stage4.add_module(str(x), vgg_pretrained_features[x])
+        self.stage5.add_module(str(23), L2pooling(channels=512))
+        for x in range(24, 30):
+            self.stage5.add_module(str(x), vgg_pretrained_features[x])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+        self.chns = [3, 64, 128, 256, 512, 512]
+
+    def forward_once(self, x):
+        h = x
+        h = self.stage1(h)
+        h_relu1_2 = h
+        h = self.stage2(h)
+        h_relu2_2 = h
+        h = self.stage3(h)
+        h_relu3_3 = h
+        h = self.stage4(h)
+        h_relu4_3 = h
+        h = self.stage5(h)
+        h_relu5_3 = h
+        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
+
+    def forward(self, x, y, as_loss=False, resize=True):
+        assert x.shape == y.shape
+        if resize:
+            x, y, window, f = downsample(x, y)
+        if as_loss:
+            feats0 = self.forward_once(x)
+            feats1 = self.forward_once(y)
+        else:
+            with torch.no_grad():
+                feats0 = self.forward_once(x)
+                feats1 = self.forward_once(y)
+        score = 0
+        layer_score = []
+        # To see the score of each layer, use the debugging mode of PyCharm.
+
+        for k in range(len(self.chns)):
+            row_padding = round(feats0[k].size(2) / window) * window - feats0[k].size(2)
+            column_padding = round(feats0[k].size(3) / window) * window - feats0[k].size(3)
+
+            pad = nn.ZeroPad2d((column_padding, 0, 0, row_padding))
+            feats0_k = pad(feats0[k])
+            feats1_k = pad(feats1[k])
+
+            tmp = ws_distance(feats0_k, feats1_k, win=window)
+            layer_score.append(torch.log(tmp + 1))
+            score = score + tmp
+        score = score / (k+1)
+
+        # For optimization, the logarithm is not applied.
+        if as_loss:
+            return score
+        # We find that the log**2 output leads to higher PLCC results, so we provide two output strategies.
+        # They only affect the PLCC of the quality assessment results.
+        elif f==1:
+            return torch.log(score + 1)
+        else:
+            return torch.log(score + 1)**2
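And a minimal sketch for the DeepWSD module, which additionally needs the POT package for `wasserstein_1d` (listed in requirements.txt). Inputs whose larger side exceeds roughly 256 pixels are average-pooled by `downsample` before the VGG-16 features are compared:

# Hypothetical standalone check; not part of the commit.
import torch
from metrics.DeepWSD import DeepWSD

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deepwsd = DeepWSD().to(device)

ref = torch.rand(1, 3, 512, 512, device=device)
dist = (ref + 0.1 * torch.randn_like(ref)).clamp(0, 1)

with torch.no_grad():
    print(deepwsd(ref, dist))  # scalar tensor; 0 for identical inputs, larger for stronger distortions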
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+DISTS-pytorch
+torch
+torchvision
+pyiqa
+POT