|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib import cm |
|
import gradio as gr |
|
from PIL import Image |
|
from pathlib import Path |
|
|
|
|
|
from inference import BathymetrySuperResolution |
|
|
|
|
|
# Where the trained weights and the model configuration live. Both locations
# can be redirected through environment variables of the same name.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", "checkpoints")
MODEL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "calibrated.pth")
CONFIG_PATH = os.getenv("CONFIG_PATH", "config.json")

# Build the model once, at import time. Any failure is reported on stdout and
# leaves ``model`` as None so the interface can still start and report the
# problem to the user instead of crashing on import.
model = None
try:
    model = BathymetrySuperResolution(
        model_type="vqvae",
        checkpoint_path=MODEL_CHECKPOINT,
        config_path=CONFIG_PATH,
    )
except Exception as e:
    print(f"Error loading model: {str(e)}")

# True only when the constructor above completed without raising.
model_loaded = model is not None
|
|
|
def _load_bathymetry(path):
    """Read the grid stored at ``path`` (``.npy`` or any PIL-readable image)."""
    if path.lower().endswith('.npy'):
        # Uploads are untrusted input: never allow pickled object arrays.
        return np.load(path, allow_pickle=False)
    # Close the image file as soon as the pixels have been copied out.
    with Image.open(path) as img:
        return np.array(img.convert('L'))


def _preview_input(data):
    """Return ``data`` rescaled so its longest side is 32 samples (display only)."""
    if data.shape == (32, 32):
        return data
    from scipy.ndimage import zoom
    return zoom(data, 32 / max(data.shape))


def _shape_label(array):
    """Format an array shape as ``"HxW"`` for use in a plot title."""
    return "x".join(str(n) for n in np.shape(array))


def _render_results(data, prediction, lower_bound, upper_bound, confidence_level):
    """Draw the six-panel result figure and return it."""
    # Use a standalone Figure rather than ``plt.figure``: pyplot keeps every
    # figure it creates in a global registry until it is explicitly closed,
    # which leaks memory in a long-running server process.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 10))

    preview = _preview_input(data)
    sr = prediction[0, 0]
    lower = lower_bound[0, 0]
    upper = upper_bound[0, 0]

    panels = [
        (231, preview, f"Input ({_shape_label(preview)})", cm.viridis),
        (232, sr, f"Super-Resolution ({_shape_label(sr)})", cm.viridis),
        (233, lower, f"Lower Bound ({confidence_level}% CI)", cm.viridis),
        (234, upper, f"Upper Bound ({confidence_level}% CI)", cm.viridis),
        (235, upper - lower, "Uncertainty Width", 'hot'),
    ]
    for position, image, title, cmap in panels:
        ax = fig.add_subplot(position)
        mappable = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        fig.colorbar(mappable, ax=ax)

    # 3D surface. ``meshgrid(x, y)`` yields arrays of shape (len(y), len(x)),
    # so x must span the columns and y the rows for X/Y to match ``sr``.
    ax3d = fig.add_subplot(236, projection='3d')
    n_rows, n_cols = np.shape(sr)
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    ax3d.plot_surface(X, Y, sr, cmap=cm.viridis, linewidth=0, antialiased=True)
    ax3d.set_title("3D Bathymetry")

    fig.tight_layout()
    return fig


def process_upload(file, confidence_level, block_size, model_type):
    """Run super-resolution on an uploaded bathymetry file.

    Args:
        file: The upload handed over by Gradio. Either an object exposing the
            path as ``.name`` or the path itself; ``None`` if nothing was
            uploaded.
        confidence_level: Confidence level in percent (e.g. ``95``); it is
            divided by 100 before being passed to ``model.predict``.
        block_size: Block size selected in the UI. Only compared against the
            loaded model's configuration; the model is not reloaded.
        model_type: Model type selected in the UI. Only compared against the
            loaded model's configuration; the model is not reloaded.

    Returns:
        A ``(figure, summary_markdown)`` tuple. On any failure the figure is
        ``None`` and the second element is a human-readable error message;
        exceptions are printed to the server log, never raised to the caller.
    """
    if file is None:
        return None, "Please upload a file."

    try:
        if not model_loaded:
            return None, "Model not loaded. Please check server logs."

        # Accept both file-like uploads (``.name``) and plain path strings.
        path = str(getattr(file, "name", file))
        data = _load_bathymetry(path)

        # The UI selections cannot change the already-loaded model, so detect
        # a mismatch here and report it in the summary rather than silently
        # presenting the requested settings as the ones that were used.
        loaded_cfg = model.config
        loaded_type = loaded_cfg['model_type']
        loaded_block = loaded_cfg['model_config']['block_size']
        settings_differ = loaded_type != model_type or loaded_block != block_size

        prediction, lower_bound, upper_bound = model.predict(
            data,
            with_uncertainty=True,
            confidence_level=confidence_level / 100.0,
        )
        uncertainty_width = model.get_uncertainty_width(lower_bound, upper_bound)

        fig = _render_results(
            data, prediction, lower_bound, upper_bound, confidence_level
        )

        summary_lines = [
            "",
            "**Super-Resolution Results:**",
            f"- **Model Type**: {model_type.upper()}",
            f"- **Block Size**: {block_size}×{block_size}",
            f"- **Confidence Level**: {confidence_level}%",
            f"- **Average Uncertainty Width**: {uncertainty_width:.4f}",
            f"- **Input Shape**: {data.shape}",
            f"- **Output Shape**: {prediction.shape[2:]}",
        ]
        if settings_differ:
            summary_lines.append(
                "- **Note**: The selected model type / block size differ from "
                "the loaded model; results were produced with the loaded "
                f"model ({loaded_type}, block size {loaded_block})."
            )
        summary_lines.append("")

        return fig, "\n".join(summary_lines)

    except Exception as e:
        # UI boundary: log the full traceback server-side and show the user a
        # short message instead of propagating the exception into Gradio.
        import traceback
        traceback.print_exc()
        return None, f"Error processing file: {e}"
|
|
|
def create_sample_data():
    """Write a synthetic 32x32 depth grid to ``samples/sample.npy``.

    The surface is a constant -4000 offset plus a sinusoidal ripple and a
    Gaussian bump centred at (x=0.3, y=0.7) on the unit square. The
    ``samples`` directory is created if it does not exist yet.

    Returns:
        The path of the saved ``.npy`` file, as a string.
    """
    axis = np.linspace(0, 1, 32)
    # Broadcasting a row vector against a column vector yields the same
    # 32x32 grid as an explicit meshgrid: x varies along columns, y along rows.
    col = axis[np.newaxis, :]
    row = axis[:, np.newaxis]

    ripples = 500 * np.sin(10 * col) * np.cos(8 * row)
    bump = 300 * np.exp(-((col - 0.3) ** 2 + (row - 0.7) ** 2) / 0.1)
    depth = -4000 + ripples + bump

    out_dir = Path("samples")
    out_dir.mkdir(exist_ok=True)
    target = out_dir / "sample.npy"
    np.save(target, depth)

    return str(target)
|
|
|
|
|
# Gradio interface definition. Components are created inside nested layout
# contexts; the order of statements determines the on-page order.
with gr.Blocks(title="Bathymetry Super-Resolution") as demo:
    # Page header and short description.
    gr.Markdown("""
    # Bathymetry Super-Resolution with Uncertainty Quantification

    This application demonstrates super-resolution of ocean floor (bathymetry) data with uncertainty estimates.
    Upload a bathymetry file (NPY or image) to see the enhanced resolution output with confidence intervals.

    The model uses a **Vector Quantized Variational Autoencoder (VQ-VAE)** with **block-based uncertainty quantification**.
    """)

    with gr.Row():
        # Left column: upload widget, settings and action buttons.
        with gr.Column():
            input_file = gr.File(label="Upload Bathymetry File (.npy or image)")

            with gr.Row():
                # Percentage; process_upload divides it by 100 before use.
                confidence_level = gr.Slider(
                    minimum=80, maximum=99, value=95, step=1,
                    label="Confidence Level (%)"
                )

                block_size = gr.Dropdown(
                    choices=[1, 2, 4, 8, 64], value=4,
                    label="Block Size"
                )

                model_type = gr.Dropdown(
                    choices=["vqvae", "srcnn", "gan"], value="vqvae",
                    label="Model Type"
                )

            with gr.Row():
                process_btn = gr.Button("Generate Super-Resolution")
                sample_btn = gr.Button("Load Sample Data")

        # Right column: the result figure and the Markdown summary.
        with gr.Column():
            output_plots = gr.Plot(label="Super-Resolution Results")
            output_text = gr.Markdown(label="Summary")

    # Run the model on click: the four input components are passed to
    # process_upload in this order, and its (figure, summary) return value
    # fills the two output components.
    process_btn.click(
        fn=process_upload,
        inputs=[input_file, confidence_level, block_size, model_type],
        outputs=[output_plots, output_text]
    )

    # Generate the synthetic sample on click and place its path into the
    # upload widget, so it can then be processed like a user upload.
    sample_btn.click(
        fn=lambda: gr.update(value=create_sample_data()),
        inputs=None,
        outputs=input_file
    )

    # Static background information and reported metrics.
    gr.Markdown("""
    ## About This Model

    This model enhances the resolution of bathymetric data from 32×32 to 64×64 while providing uncertainty estimates.
    It was trained on bathymetry data from multiple ocean regions including the Eastern Pacific Basin, Western Pacific Region, and Indian Ocean Basin.

    The uncertainty estimates help identify areas where the model is less confident in its predictions, which is crucial for:
    - Risk assessment in coastal hazard modeling
    - Climate change impact analysis
    - Tsunami propagation simulation

    ## Model Performance

    | Model | SSIM | PSNR | MSE | MAE | UWidth | CalErr |
    |-------|------|------|-----|-----|--------|--------|
    | UA-VQ-VAE | 0.9433 | 26.8779 | 0.0021 | 0.0317 | 0.1046 | 0.0664 |
    """)
|
|
|
|
|
# Script entry point: report whether the model is available, then start the
# Gradio server. The interface is launched even when loading failed so the
# user sees an error in the UI rather than no page at all.
if __name__ == "__main__":
    startup_message = (
        "Model loaded successfully. Starting Gradio interface."
        if model_loaded
        else "Warning: Model not loaded. Demo will display errors when processing files."
    )
    print(startup_message)

    demo.launch()