Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr
from DISTS_pytorch import DISTS
from torchvision.io import ImageReadMode, read_image
import torch
import torchvision.transforms.v2 as transforms
import spaces
from metrics.DeepDC import DeepDC
from metrics.DeepWSD import DeepWSD
from metrics.ADISTS import ADISTS
from dreamsim import dreamsim
# pyiqa pins older versions of some packages (specifically transformers==4.37.2),
# which causes dependency conflicts when it is installed with the rest of the
# requirements. It is therefore installed here at runtime, without its
# dependencies, if it is not importable yet.
try:
    import pyiqa
except ImportError:
    import importlib
    import subprocess
    import sys

    print("pyiqa not found. Installing...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pyiqa==0.1.14.1", "--no-deps"])
    # The import system caches directory contents; the package was installed
    # after this interpreter started, so clear those caches before importing it.
    importlib.invalidate_caches()
    import pyiqa

# Download the DreamSim weights once at startup (on CPU; the returned objects are
# discarded). Presumably dreamsim caches the checkpoint on disk so later
# dreamsim(pretrained=True, ...) calls reuse it — TODO confirm.
_, _ = dreamsim(pretrained=True, device="cpu")
class Evaluator:
    """Score an image pair with a fixed collection of full-reference IQA metrics.

    All metric objects are built once, in ``__init__``, and moved to ``device``.
    Each metric's display name is prefixed with ``↓`` when a lower score is
    better and ``↑`` when a higher score is better.
    """

    def __init__(self, device):
        """Create every metric on ``device``.

        Args:
            device: ``torch.device`` (or device string) used for both the input
                images and the metric models.
        """
        self.device = device
        # Convert uint8 images to float32 and rescale [0, 255] -> [0, 1].
        self.transform = transforms.ToDtype(dtype=torch.float32, scale=True)
        self.metrics = self._init_metrics()

    def _init_metrics(self):
        """Return an ordered mapping of display name -> metric callable.

        Every callable is invoked as ``metric(img1, img2)`` and must return a
        single-element tensor (``evaluate`` calls ``.item()`` on it).
        ``torch.nn.Module`` metrics are switched to eval mode because they are
        only used for inference here.
        """
        metrics = {
            "↓ MSE": torch.nn.functional.mse_loss,
            "↓ L1": torch.nn.functional.l1_loss,
            "↓ DISTS": DISTS().to(self.device),
            "↓ ADISTS": ADISTS().to(self.device),
            "↓ DeepDC": DeepDC().to(self.device),
            "↓ DeepWSD": DeepWSD().to(self.device),
            "↓ LPIPS": pyiqa.create_metric("lpips", device=self.device),
            # dreamsim() returns a pair; only its first element is used as the metric.
            "↓ DreamSim": dreamsim(pretrained=True, device=self.device)[0],
            "↑ PSNR": pyiqa.create_metric("psnr", device=self.device),
            "↑ SSIM": pyiqa.create_metric("ssim", device=self.device),
            "↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device),
            "↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device),
            "↑ FSIM": pyiqa.create_metric("fsim", device=self.device),
        }
        for metric in metrics.values():
            if isinstance(metric, torch.nn.Module):
                metric.eval()
        return metrics

    def _load_image(self, img_fname):
        """Read an image file as a (1, 3, H, W) float32 tensor in [0, 1] on ``self.device``."""
        # Force RGB so grayscale / RGBA files yield 3 channels; the deep metrics
        # presumably expect 3-channel input (TODO confirm per metric).
        img = read_image(img_fname, mode=ImageReadMode.RGB)
        return self.transform(img).unsqueeze(0).to(self.device)

    def evaluate(self, img_fname1, img_fname2):
        """Compare two images with every configured metric.

        Args:
            img_fname1: path of the first image.
            img_fname2: path of the second image.

        Returns:
            A newline-separated report with one ``"<name>: <score>"`` line per
            metric, or an error message string if the images differ in size.
        """
        img1 = self._load_image(img_fname1)
        img2 = self._load_image(img_fname2)
        # check images are the same size
        if img1.shape != img2.shape:
            return "Input images must have the same dimensions!"
        # Only scores are needed, so don't build autograd graphs.
        with torch.no_grad():
            lines = [
                f"{name:<10}: {metric(img1, img2).item():3,.5f}"
                for name, metric in self.metrics.items()
            ]
        return "\n".join(lines)
def get_evaluator():
    """Return the cached Evaluator, creating it on the first call.

    The instance is stored as an attribute of this function, so it is shared by
    every call made within the same Python process. The device (CUDA when
    available, otherwise CPU) is resolved only when the instance is created.
    """
    if not hasattr(get_evaluator, "evaluator"):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        get_evaluator.evaluator = Evaluator(device)
    return get_evaluator.evaluator
@spaces.GPU
def compute_similarity(img1_path, img2_path):
    """Gradio callback: compare two uploaded images and return a metrics report.

    Decorated with ``spaces.GPU`` so that, on ZeroGPU hardware, a GPU is attached
    for the duration of the call (per the ``spaces`` package docs the decorator
    has no effect elsewhere).

    Args:
        img1_path: file path of the first image, or a falsy value if missing.
        img2_path: file path of the second image, or a falsy value if missing.

    Returns:
        The report string from ``Evaluator.evaluate``, or a prompt asking the
        user to upload both images.
    """
    if not img1_path or not img2_path:
        return "Please upload both images!"
    return get_evaluator().evaluate(img1_path, img2_path)
def create_interface():
    """Build the Gradio Blocks UI for comparing two images.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    examples = [
        ["examples/01_1.jpg", "examples/01_1.jpg"],  # Extra example: an image compared with itself
        ["examples/01_1.jpg", "examples/noise.jpg"],
        *[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)],
    ]

    # Custom CSS
    css = """
.center-header {
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 0 0 10px 0;
}
.monospace-text {
    font-family: 'Courier New', Courier, monospace;
}
.metrics-table {
    width: 100%;
    border-collapse: collapse;
}
.metrics-table td {
    padding: 10px;
    vertical-align: top;
}
"""

    pyiqa_url = "https://github.com/chaofengc/IQA-PyTorch"
    # One entry per metric shown in the acknowledgements table:
    # (name, rough score range, lower is better?, source label, source URL).
    metric_info = [
        ("MSE", "[0, ∞)", "Yes", "torch", "https://docs.pytorch.org/docs/stable/generated/torch.nn.MSELoss.html"),
        ("L1", "[0, ∞)", "Yes", "torch", "https://docs.pytorch.org/docs/stable/generated/torch.nn.L1Loss.html"),
        ("DISTS", "[0, 1]", "Yes", "official", "https://github.com/dingkeyan93/DISTS"),
        ("ADISTS", "~[0, 1]", "Yes", "official", "https://github.com/dingkeyan93/A-DISTS"),
        ("DeepDC", "[0, 1]", "Yes", "official", "https://github.com/h4nwei/DeepDC"),
        ("DeepWSD", "[0, ∞)", "Yes", "official", "https://github.com/Buka-Xing/DeepWSD"),
        ("LPIPS", "[0, 1]", "Yes", "pyiqa", pyiqa_url),
        ("DreamSim", "[0, 1]", "Yes", "official", "https://github.com/ssundaram21/dreamsim"),
        ("PSNR", "[0, ∞)", "No", "pyiqa", pyiqa_url),
        ("SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("MS-SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("CW-SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("FSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
    ]
    # Joined with single newlines (no blank lines) so the rows stay inside one HTML block.
    table_rows = "\n".join(
        f'<tr><td>{name}</td><td>{score_range}</td><td>{lower_is_better}</td>'
        f'<td><a href="{url}">{source}</a></td></tr>'
        for name, score_range, lower_is_better, source, url in metric_info
    )

    # Markdown content is kept at column 0 so no line is indented enough to
    # become a Markdown code block.
    header_md = """
<div class='center-header'><h1>Full-Reference Image Quality Assessment</h1></div>
Upload two images to compute various similarity metrics.<br>
**Note**: Images must have identical dimensions. Code will run much faster locally: due to ZeroGPU setup, metrics are re-initialized on every run.
"""
    acknowledgements_md = f"""
<div class='center-header'><h3>Acknowledgements</h3></div>
- Example images from [TryOffDiff](https://rizavelioglu.github.io/tryoffdiff) paper, which are sampled from VITON-HD dataset.
- Metrics (*score range is only rough estimation, actual score range may vary*):
<table class="metrics-table">
<tr>
<th>Metric</th>
<th>Score Range</th>
<th>Lower is better?</th>
<th>Source</th>
</tr>
{table_rows}
</table>
"""

    # Add UI elements
    with gr.Blocks(title="FR-IQA", css=css) as demo:
        gr.Markdown(header_md)
        with gr.Row():
            with gr.Column(scale=2):
                img_fname1 = gr.Image(type="filepath", label="Image#1", height=512, width=512)
            with gr.Column(scale=2):
                img_fname2 = gr.Image(type="filepath", label="Image#2", height=512, width=512)
            with gr.Column(scale=1):
                metrics_output = gr.Textbox(label="Metrics Output", lines=22, elem_classes="monospace-text", show_copy_button=True)
        with gr.Row():
            submit_btn = gr.Button("Compute Metrics", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    examples=examples,
                    inputs=[img_fname1, img_fname2],
                    fn=compute_similarity,
                    outputs=metrics_output,
                    label="Example Pairs (all are 1024×768)",
                    cache_examples=True,
                    cache_mode="lazy",
                    examples_per_page=6,
                )
            with gr.Column(scale=2):
                gr.Markdown(acknowledgements_md)
        submit_btn.click(
            fn=compute_similarity,
            inputs=[img_fname1, img_fname2],
            outputs=[metrics_output],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it: no public share link, server-side rendering off.
    demo = create_interface()
    demo.launch(ssr_mode=False, share=False)