Spaces:
Sleeping
Sleeping
[Fix] PSNR computation
Browse files- app.py +9 -10
- factories.py +1 -1
app.py
CHANGED
|
@@ -52,6 +52,7 @@ def generate_imgs_from_user(image,
|
|
| 52 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
| 53 |
|
| 54 |
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
|
|
|
| 55 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
| 56 |
physics: PhysicsWithGenerator, use_gen: bool,
|
| 57 |
baseline: BaselineModel, model: EvalModel,
|
|
@@ -108,26 +109,24 @@ def generate_imgs(x: torch.Tensor,
|
|
| 108 |
if out_baseline.shape != out.shape:
|
| 109 |
out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
### Metrics
|
| 112 |
metrics_y = ""
|
| 113 |
metrics_out = ""
|
| 114 |
metrics_out_baseline = ""
|
| 115 |
for metric in metrics:
|
| 116 |
-
if y.shape == x.shape:
|
| 117 |
-
|
| 118 |
metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
|
| 119 |
metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
|
| 120 |
metrics_out += f"Inference time = {ram_time:.3f}s"
|
| 121 |
metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
|
| 122 |
|
| 123 |
-
### Process y when y shape is different from x shape
|
| 124 |
-
if physics.name == "MRI":
|
| 125 |
-
y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
|
| 126 |
-
elif physics.name == "CT":
|
| 127 |
-
y_plot = physics.physics.A_adjoint(y)
|
| 128 |
-
else:
|
| 129 |
-
y_plot = y.clone()
|
| 130 |
-
|
| 131 |
### Processing images for plotting :
|
| 132 |
# - clip value outside of [0,1]
|
| 133 |
# - shape (1, C, H, W) -> (C, H, W)
|
|
|
|
| 52 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
| 53 |
|
| 54 |
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
| 55 |
+
|
| 56 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
| 57 |
physics: PhysicsWithGenerator, use_gen: bool,
|
| 58 |
baseline: BaselineModel, model: EvalModel,
|
|
|
|
| 109 |
if out_baseline.shape != out.shape:
|
| 110 |
out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
|
| 111 |
|
| 112 |
+
### Process y when y shape is different from x shape
|
| 113 |
+
if physics.name == 'MRI' and physics.name == 'CT':
|
| 114 |
+
y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
|
| 115 |
+
else:
|
| 116 |
+
y_plot = y.clone()
|
| 117 |
+
|
| 118 |
### Metrics
|
| 119 |
metrics_y = ""
|
| 120 |
metrics_out = ""
|
| 121 |
metrics_out_baseline = ""
|
| 122 |
for metric in metrics:
|
| 123 |
+
#if y.shape == x.shape:
|
| 124 |
+
metrics_y += f"{metric.name} = {metric(y_plot, x).item():.4f}" + "\n"
|
| 125 |
metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
|
| 126 |
metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
|
| 127 |
metrics_out += f"Inference time = {ram_time:.3f}s"
|
| 128 |
metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
### Processing images for plotting :
|
| 131 |
# - clip value outside of [0,1]
|
| 132 |
# - shape (1, C, H, W) -> (C, H, W)
|
factories.py
CHANGED
|
@@ -140,7 +140,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
| 140 |
max_iter=10,
|
| 141 |
)
|
| 142 |
self.physics_generator = None
|
| 143 |
-
self.generator = SigmaGenerator(sigma_min=
|
| 144 |
self.saved_params = {"updatable_params": {"sigma": 1e-4},
|
| 145 |
"updatable_params_converter": {"sigma": float},
|
| 146 |
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|
|
|
|
| 140 |
max_iter=10,
|
| 141 |
)
|
| 142 |
self.physics_generator = None
|
| 143 |
+
self.generator = SigmaGenerator(sigma_min=1e-5, sigma_max=1e-4, device=device_str)
|
| 144 |
self.saved_params = {"updatable_params": {"sigma": 1e-4},
|
| 145 |
"updatable_params_converter": {"sigma": float},
|
| 146 |
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|