import os
import traceback
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import gradio as gr
from PIL import Image
from pathlib import Path

# Import the inference module
from inference import BathymetrySuperResolution

# Define checkpoint and config paths
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "checkpoints")
MODEL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "calibrated.pth")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "config.json")

# Side length (in cells) the input panel is resampled to for display.
DISPLAY_INPUT_SIZE = 32

# Initialize model once at import time; failures are reported per request.
try:
    model = BathymetrySuperResolution(
        model_type="vqvae",
        checkpoint_path=MODEL_CHECKPOINT,
        config_path=CONFIG_PATH,
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {str(e)}")
    model = None
    model_loaded = False


def _resolve_upload_path(file):
    """Return the filesystem path of a Gradio upload (path string or object with ``.name``)."""
    # Depending on the Gradio version / ``gr.File(type=...)``, the handler receives
    # either a plain path string or a tempfile-like object exposing ``.name``.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _load_bathymetry(path):
    """Load a grid from a ``.npy`` file, otherwise decode it as an image converted to grayscale."""
    if path.lower().endswith(".npy"):
        # Uploads are untrusted: never allow pickled object arrays.
        return np.load(path, allow_pickle=False)
    # The context manager closes the underlying file handle.
    with Image.open(path) as img:
        return np.array(img.convert("L"))


def _loaded_config():
    """Return ``(model_type, block_size)`` as stored in the loaded model's config."""
    return model.config["model_type"], model.config["model_config"]["block_size"]


def _config_mismatch_note(requested_type, requested_block, loaded_type, loaded_block):
    """Return a Markdown warning if the UI selection differs from the loaded config, else ''."""
    if requested_type == loaded_type and requested_block == loaded_block:
        return ""
    # Reloading/reconfiguring the model per request is not implemented, so make
    # it explicit that the loaded model produced the results.
    return (
        f"> **Note:** requested {str(requested_type).upper()} with block size "
        f"{requested_block}, but the loaded model is {str(loaded_type).upper()} "
        f"with block size {loaded_block}; results come from the loaded model."
    )


def _resize_for_display(data, size=DISPLAY_INPUT_SIZE):
    """Resample a 2-D grid to exactly ``size`` x ``size`` for the input panel."""
    grid = np.squeeze(np.asarray(data))
    if grid.ndim != 2:
        raise ValueError(f"Expected a 2-D bathymetry grid, got shape {np.shape(data)}")
    if grid.shape == (size, size):
        return grid
    from scipy.ndimage import zoom
    # Per-axis factors so non-square inputs also come out size x size; a single
    # factor derived from max(shape) would preserve the aspect ratio instead.
    return zoom(grid, (size / grid.shape[0], size / grid.shape[1]))


def _build_figure(input_grid, prediction, lower_bound, upper_bound, confidence_level):
    """Build the 2x3 results figure: input, prediction, CI bounds, CI width, 3-D surface.

    ``prediction``/``lower_bound``/``upper_bound`` are indexed as ``[0, 0]`` to get
    a 2-D map, i.e. they are assumed to be (batch, channel, H, W) arrays.
    """
    fig = plt.figure(figsize=(15, 10))
    surface = prediction[0, 0]
    height, width = surface.shape

    panels = [
        (231, input_grid, cm.viridis,
         f"Input ({input_grid.shape[0]}x{input_grid.shape[1]})"),
        (232, surface, cm.viridis, f"Super-Resolution ({height}x{width})"),
        (233, lower_bound[0, 0], cm.viridis, f"Lower Bound ({confidence_level}% CI)"),
        (234, upper_bound[0, 0], cm.viridis, f"Upper Bound ({confidence_level}% CI)"),
        (235, upper_bound[0, 0] - lower_bound[0, 0], "hot", "Uncertainty Width"),
    ]
    for position, image, cmap, title in panels:
        ax = fig.add_subplot(position)
        im = ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

    # 3D surface plot. meshgrid(columns, rows) yields (height, width) coordinate
    # arrays, matching the surface even when the output is not square.
    ax6 = fig.add_subplot(236, projection="3d")
    X, Y = np.meshgrid(np.arange(width), np.arange(height))
    ax6.plot_surface(X, Y, surface, cmap=cm.viridis, linewidth=0, antialiased=True)
    ax6.set_title("3D Bathymetry")

    fig.tight_layout()
    # Drop pyplot's global reference so figures from repeated requests are not
    # kept alive for the server's lifetime; the Figure object stays renderable.
    plt.close(fig)
    return fig


def _build_summary(model_type, block_size, confidence_level, uncertainty_width,
                   input_shape, output_shape, note=""):
    """Return the Markdown summary shown next to the plots."""
    lines = [
        "**Super-Resolution Results:**",
        "",
        f"- **Model Type**: {str(model_type).upper()}",
        f"- **Block Size**: {block_size}×{block_size}",
        f"- **Confidence Level**: {confidence_level}%",
        f"- **Average Uncertainty Width**: {uncertainty_width:.4f}",
        f"- **Input Shape**: {input_shape}",
        f"- **Output Shape**: {output_shape}",
    ]
    if note:
        lines += ["", note]
    return "\n".join(lines)


def process_upload(file, confidence_level, block_size, model_type):
    """Process uploaded bathymetry file.

    Args:
        file: Gradio upload — a path string or an object with a ``.name`` path.
        confidence_level: Confidence level in percent (e.g. 95).
        block_size: Block size selected in the UI.
        model_type: Model type selected in the UI.

    Returns:
        ``(figure, markdown_summary)`` on success, ``(None, message)`` otherwise.
        Exceptions are caught, their traceback printed, and reported as a message.
    """
    if file is None:
        return None, "Please upload a file."

    # Check if the model is loaded
    if not model_loaded:
        return None, "Model not loaded. Please check server logs."

    try:
        data = _load_bathymetry(_resolve_upload_path(file))

        loaded_type, loaded_block = _loaded_config()
        note = _config_mismatch_note(model_type, block_size, loaded_type, loaded_block)

        # Run the prediction
        prediction, lower_bound, upper_bound = model.predict(
            data,
            with_uncertainty=True,
            confidence_level=confidence_level / 100.0,  # Convert percentage to fraction
        )
        uncertainty_width = model.get_uncertainty_width(lower_bound, upper_bound)

        fig = _build_figure(
            _resize_for_display(data), prediction, lower_bound, upper_bound, confidence_level
        )
        # Report the configuration that actually produced the result.
        summary = _build_summary(
            loaded_type,
            loaded_block,
            confidence_level,
            uncertainty_width,
            data.shape,
            prediction.shape[2:],
            note,
        )
        return fig, summary
    except Exception as e:
        traceback.print_exc()
        return None, f"Error processing file: {str(e)}"


def create_sample_data():
    """Create a sample bathymetry data file for demonstration.

    Writes a synthetic 32x32 grid to ``samples/sample.npy`` (relative to the
    current working directory) and returns that path as a string.
    """
    # Create a synthetic bathymetry profile with features
    x = np.linspace(0, 1, 32)
    y = np.linspace(0, 1, 32)
    xx, yy = np.meshgrid(x, y)

    # Create a surface with a ridge and a valley
    z = (
        -4000
        + 500 * np.sin(10 * xx) * np.cos(8 * yy)
        + 300 * np.exp(-((xx - 0.3) ** 2 + (yy - 0.7) ** 2) / 0.1)
    )

    # Save to a file under ./samples
    sample_dir = Path("samples")
    sample_dir.mkdir(exist_ok=True)
    sample_path = sample_dir / "sample.npy"
    np.save(sample_path, z)

    return str(sample_path)


# Create the Gradio interface
with gr.Blocks(title="Bathymetry Super-Resolution") as demo:
    gr.Markdown("""
    # Bathymetry Super-Resolution with Uncertainty Quantification

    This application demonstrates super-resolution of ocean floor (bathymetry) data with uncertainty estimates.
    Upload a bathymetry file (NPY or image) to see the enhanced resolution output with confidence intervals.

    The model uses a **Vector Quantized Variational Autoencoder (VQ-VAE)** with **block-based uncertainty quantification**.
    """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload Bathymetry File (.npy or image)")

            with gr.Row():
                confidence_level = gr.Slider(
                    minimum=80, maximum=99, value=95, step=1,
                    label="Confidence Level (%)"
                )
                block_size = gr.Dropdown(
                    choices=[1, 2, 4, 8, 64], value=4,
                    label="Block Size"
                )

            model_type = gr.Dropdown(
                choices=["vqvae", "srcnn", "gan"], value="vqvae",
                label="Model Type"
            )

            with gr.Row():
                process_btn = gr.Button("Generate Super-Resolution")
                sample_btn = gr.Button("Load Sample Data")

        with gr.Column():
            output_plots = gr.Plot(label="Super-Resolution Results")
            output_text = gr.Markdown(label="Summary")

    # Set up button actions
    process_btn.click(
        fn=process_upload,
        inputs=[input_file, confidence_level, block_size, model_type],
        outputs=[output_plots, output_text]
    )

    # Sample data generation
    sample_btn.click(
        fn=lambda: gr.update(value=create_sample_data()),
        inputs=None,
        outputs=input_file
    )

    gr.Markdown("""
    ## About This Model

    This model enhances the resolution of bathymetric data from 32×32 to 64×64 while providing uncertainty estimates.
    It was trained on bathymetry data from multiple ocean regions including the Eastern Pacific Basin, Western Pacific Region, and Indian Ocean Basin.

    The uncertainty estimates help identify areas where the model is less confident in its predictions, which is crucial for:
    - Risk assessment in coastal hazard modeling
    - Climate change impact analysis
    - Tsunami propagation simulation

    ## Model Performance

    | Model | SSIM | PSNR | MSE | MAE | UWidth | CalErr |
    |-------|------|------|-----|-----|--------|--------|
    | UA-VQ-VAE | 0.9433 | 26.8779 | 0.0021 | 0.0317 | 0.1046 | 0.0664 |
    """)

# Launch the demo
if __name__ == "__main__":
    if model_loaded:
        print("Model loaded successfully. Starting Gradio interface.")
    else:
        print("Warning: Model not loaded. Demo will display errors when processing files.")

    demo.launch()