Jose Marie Antonio Minoza committed
Commit 95f0e22 · 1 Parent(s): 260f5fa

Initial commit

Files changed (8)
  1. app.py +228 -0
  2. checkpoints/calibrated.pth +3 -0
  3. config.json +16 -0
  4. inference.py +207 -0
  5. models/loader.py +147 -0
  6. models/uncertainty.py +188 -0
  7. models/vqvae.py +262 -0
  8. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib import cm
+ import gradio as gr
+ from PIL import Image
+ from pathlib import Path
+
+ # Import the inference module
+ from inference import BathymetrySuperResolution
+
+ # Define checkpoint and config paths
+ CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "checkpoints")
+ MODEL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "calibrated.pth")
+ CONFIG_PATH = os.environ.get("CONFIG_PATH", "config.json")
+
+ # Initialize model
+ try:
+     model = BathymetrySuperResolution(
+         model_type="vqvae",
+         checkpoint_path=MODEL_CHECKPOINT,
+         config_path=CONFIG_PATH
+     )
+     model_loaded = True
+ except Exception as e:
+     print(f"Error loading model: {str(e)}")
+     model = None
+     model_loaded = False
+
+ def process_upload(file, confidence_level, block_size, model_type):
+     """Process an uploaded bathymetry file."""
+     if file is None:
+         return None, "Please upload a file."
+
+     try:
+         # Check if the model is loaded
+         if not model_loaded:
+             return None, "Model not loaded. Please check server logs."
+
+         # Load the data
+         if file.name.endswith('.npy'):
+             data = np.load(file.name)
+         else:
+             # Try to load as an image
+             img = Image.open(file.name).convert('L')
+             data = np.array(img)
+
+         # Update model configuration if needed
+         if model.config['model_type'] != model_type or model.config['model_config']['block_size'] != block_size:
+             # In a real app, you would reload the model or adjust the configuration
+             pass
+
+         # Run the prediction
+         prediction, lower_bound, upper_bound = model.predict(
+             data,
+             with_uncertainty=True,
+             confidence_level=confidence_level / 100.0  # Convert percentage to fraction
+         )
+
+         # Calculate uncertainty width
+         uncertainty_width = model.get_uncertainty_width(lower_bound, upper_bound)
+
+         # Create visualization
+         fig = plt.figure(figsize=(15, 10))
+
+         # Original input (resized to 32x32 if needed)
+         ax1 = fig.add_subplot(231)
+         if data.shape != (32, 32):
+             from scipy.ndimage import zoom
+             zoom_factor = 32 / max(data.shape)
+             input_data = zoom(data, zoom_factor)
+         else:
+             input_data = data
+         im1 = ax1.imshow(input_data, cmap=cm.viridis)
+         ax1.set_title("Input (32x32)")
+         plt.colorbar(im1, ax=ax1)
+
+         # Super-resolution output
+         ax2 = fig.add_subplot(232)
+         im2 = ax2.imshow(prediction[0, 0], cmap=cm.viridis)
+         ax2.set_title("Super-Resolution (64x64)")
+         plt.colorbar(im2, ax=ax2)
+
+         # Lower bound
+         ax3 = fig.add_subplot(233)
+         im3 = ax3.imshow(lower_bound[0, 0], cmap=cm.viridis)
+         ax3.set_title(f"Lower Bound ({confidence_level}% CI)")
+         plt.colorbar(im3, ax=ax3)
+
+         # Upper bound
+         ax4 = fig.add_subplot(234)
+         im4 = ax4.imshow(upper_bound[0, 0], cmap=cm.viridis)
+         ax4.set_title(f"Upper Bound ({confidence_level}% CI)")
+         plt.colorbar(im4, ax=ax4)
+
+         # Uncertainty width visualization
+         ax5 = fig.add_subplot(235)
+         uncertainty_map = upper_bound[0, 0] - lower_bound[0, 0]
+         im5 = ax5.imshow(uncertainty_map, cmap='hot')
+         ax5.set_title("Uncertainty Width")
+         plt.colorbar(im5, ax=ax5)
+
+         # 3D surface plot
+         ax6 = fig.add_subplot(236, projection='3d')
+         x = np.arange(0, prediction.shape[2])
+         y = np.arange(0, prediction.shape[3])
+         X, Y = np.meshgrid(x, y)
+         surf = ax6.plot_surface(X, Y, prediction[0, 0], cmap=cm.viridis,
+                                 linewidth=0, antialiased=True)
+         ax6.set_title("3D Bathymetry")
+
+         plt.tight_layout()
+
+         # Return the figure and a summary text
+         summary = f"""
+ **Super-Resolution Results:**
+ - **Model Type**: {model_type.upper()}
+ - **Block Size**: {block_size}×{block_size}
+ - **Confidence Level**: {confidence_level}%
+ - **Average Uncertainty Width**: {uncertainty_width:.4f}
+ - **Input Shape**: {data.shape}
+ - **Output Shape**: {prediction.shape[2:]}
+ """
+
+         return fig, summary
+
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return None, f"Error processing file: {str(e)}"
+
+ def create_sample_data():
+     """Create a sample bathymetry data file for demonstration."""
+     # Create a synthetic bathymetry profile with features
+     x = np.linspace(0, 1, 32)
+     y = np.linspace(0, 1, 32)
+     xx, yy = np.meshgrid(x, y)
+
+     # Create a surface with a ridge and a valley
+     z = -4000 + 500 * np.sin(10 * xx) * np.cos(8 * yy) + 300 * np.exp(-((xx - 0.3)**2 + (yy - 0.7)**2) / 0.1)
+
+     # Save to a temporary file
+     sample_dir = Path("samples")
+     sample_dir.mkdir(exist_ok=True)
+     sample_path = sample_dir / "sample_bathymetry.npy"
+     np.save(sample_path, z)
+
+     return str(sample_path)
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Bathymetry Super-Resolution") as demo:
+     gr.Markdown("""
+ # Bathymetry Super-Resolution with Uncertainty Quantification
+
+ This application demonstrates super-resolution of ocean floor (bathymetry) data with uncertainty estimates.
+ Upload a bathymetry file (NPY or image) to see the enhanced-resolution output with confidence intervals.
+
+ The model uses a **Vector Quantized Variational Autoencoder (VQ-VAE)** with **block-based uncertainty quantification**.
+ """)
+
+     with gr.Row():
+         with gr.Column():
+             input_file = gr.File(label="Upload Bathymetry File (.npy or image)")
+
+             with gr.Row():
+                 confidence_level = gr.Slider(
+                     minimum=80, maximum=99, value=95, step=1,
+                     label="Confidence Level (%)"
+                 )
+
+                 block_size = gr.Dropdown(
+                     choices=[1, 2, 4, 8, 64], value=4,
+                     label="Block Size"
+                 )
+
+                 model_type = gr.Dropdown(
+                     choices=["vqvae", "srcnn", "gan"], value="vqvae",
+                     label="Model Type"
+                 )
+
+             with gr.Row():
+                 process_btn = gr.Button("Generate Super-Resolution")
+                 sample_btn = gr.Button("Load Sample Data")
+
+         with gr.Column():
+             output_plots = gr.Plot(label="Super-Resolution Results")
+             output_text = gr.Markdown(label="Summary")
+
+     # Set up button actions
+     process_btn.click(
+         fn=process_upload,
+         inputs=[input_file, confidence_level, block_size, model_type],
+         outputs=[output_plots, output_text]
+     )
+
+     # Sample data generation
+     sample_btn.click(
+         fn=lambda: gr.update(value=create_sample_data()),
+         inputs=None,
+         outputs=input_file
+     )
+
+     gr.Markdown("""
+ ## About This Model
+
+ This model enhances the resolution of bathymetric data from 32×32 to 64×64 while providing uncertainty estimates.
+ It was trained on bathymetry data from multiple ocean regions, including the Eastern Pacific Basin, the Western Pacific Region, and the Indian Ocean Basin.
+
+ The uncertainty estimates help identify areas where the model is less confident in its predictions, which is crucial for:
+ - Risk assessment in coastal hazard modeling
+ - Climate change impact analysis
+ - Tsunami propagation simulation
+
+ ## Model Performance
+
+ | Model | SSIM | PSNR | MSE | MAE | UWidth | CalErr |
+ |-------|------|------|-----|-----|--------|--------|
+ | UA-VQ-VAE | 0.9433 | 26.8779 | 0.0021 | 0.0317 | 0.1046 | 0.0664 |
+ """)
+
+ # Launch the demo
+ if __name__ == "__main__":
+     if model_loaded:
+         print("Model loaded successfully. Starting Gradio interface.")
+     else:
+         print("Warning: Model not loaded. Demo will display errors when processing files.")
+
+     demo.launch()
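For quick local testing without the browser UI, `process_upload` can be exercised directly. The snippet below is a hypothetical smoke test, not part of the commit: it assumes the checkpoint in `checkpoints/calibrated.pth` is available and stubs the file object Gradio normally passes, since `process_upload` only reads its `.name` attribute.

```python
# Hypothetical smoke test for app.py (assumes the checkpoint is present).
import numpy as np
from app import process_upload, create_sample_data

class FakeUpload:
    """Stub for the Gradio file object; process_upload only uses .name."""
    def __init__(self, name):
        self.name = name

sample_path = create_sample_data()  # writes samples/sample_bathymetry.npy
fig, summary = process_upload(
    FakeUpload(sample_path),
    confidence_level=95, block_size=4, model_type="vqvae",
)
print(summary)
if fig is not None:
    fig.savefig("smoke_test_output.png")
```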
checkpoints/calibrated.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c2ba632b12fe06e6c684c964ea07424b074e36ed533c3b03fa3bb4e8bf1c67e
+ size 235336891
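Note that `calibrated.pth` is committed as a Git LFS pointer, so a clone without `git lfs` yields only this three-line stub rather than the ~235 MB weights. If the repository is hosted on the Hugging Face Hub, a sketch like the following should fetch the real file (the `repo_id` is a placeholder, not the actual repository name):

```python
from huggingface_hub import hf_hub_download  # pinned in requirements.txt (>=0.14.0)

# NOTE: repo_id is a placeholder -- substitute the actual Hub repository.
ckpt_path = hf_hub_download(
    repo_id="<user>/<bathymetry-super-resolution>",
    filename="checkpoints/calibrated.pth",
)
print(ckpt_path)  # local cache path of the downloaded checkpoint
```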
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "model_type": "vqvae",
+   "model_config": {
+     "in_channels": 1,
+     "hidden_dims": [32, 64, 128, 256],
+     "num_embeddings": 512,
+     "embedding_dim": 256,
+     "block_size": 4
+   },
+   "normalization": {
+     "mean": -3911.3894,
+     "std": 1172.8374,
+     "min": 0.0,
+     "max": 1.0
+   }
+ }
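The `normalization` block records the training-set depth statistics (mean ≈ -3911 m, std ≈ 1173 m) that `inference.py` applies in `preprocess()` and undoes in `denormalize()`. A minimal sanity check of how a raw depth maps into model space (hypothetical snippet, mirroring the z-score step in `preprocess()`; metres assumed as the depth unit):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

norm = cfg["normalization"]
depth_m = -4200.0  # example raw depth in metres (assumed unit)
z = (depth_m - norm["mean"]) / norm["std"]  # z-score step from preprocess()
print(f"{depth_m} m -> z = {z:.3f}")        # ~ -0.246

assert cfg["model_config"]["block_size"] == 4
```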
inference.py ADDED
@@ -0,0 +1,207 @@
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ import torch.nn.functional as F
+ import json
+
+ # Import the model components
+ from models.loader import ModelLoader
+ from models.uncertainty import BlockUncertaintyTracker
+
+ class BathymetrySuperResolution:
+     """
+     Bathymetry super-resolution model with uncertainty estimation.
+     """
+     def __init__(self, model_type="vqvae", checkpoint_path=None, config_path=None):
+         """
+         Initialize the super-resolution model with uncertainty awareness.
+
+         Args:
+             model_type: Type of model ('srcnn', 'gan', or 'vqvae')
+             checkpoint_path: Path to model checkpoint
+             config_path: Path to configuration file
+         """
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Load config if provided
+         if config_path is not None and os.path.exists(config_path):
+             with open(config_path, 'r') as f:
+                 self.config = json.load(f)
+         else:
+             # Default configuration
+             self.config = {
+                 "model_type": model_type,
+                 "model_config": {
+                     "in_channels": 1,
+                     "hidden_dims": [32, 64, 128, 256],
+                     "num_embeddings": 512,
+                     "embedding_dim": 256,
+                     "block_size": 4
+                 },
+                 "normalization": {
+                     "mean": -3911.3894,
+                     "std": 1172.8374,
+                     "min": 0.0,
+                     "max": 1.0
+                 }
+             }
+
+         # Initialize model loader
+         self.model_loader = ModelLoader()
+
+         # Load model
+         if checkpoint_path is not None and os.path.exists(checkpoint_path):
+             self.model = self.model_loader.load_model(
+                 self.config['model_type'],
+                 checkpoint_path,
+                 config_overrides=self.config.get('model_config', {})
+             )
+         else:
+             raise ValueError("Checkpoint path not provided or invalid")
+
+         # Ensure model is in eval mode
+         self.model.eval()
+
+         # Load normalization parameters
+         self.mean = self.config['normalization']['mean']
+         self.std = self.config['normalization']['std']
+         self.min_val = self.config['normalization']['min']
+         self.max_val = self.config['normalization']['max']
+
+     def preprocess(self, data):
+         """
+         Preprocess input data for the model.
+
+         Args:
+             data: Input array/image (numpy array, PIL Image, or tensor)
+
+         Returns:
+             Preprocessed tensor
+         """
+         # Convert PIL Image to numpy if needed
+         if isinstance(data, Image.Image):
+             data = np.array(data)
+
+         # Convert numpy to tensor if needed
+         if isinstance(data, np.ndarray):
+             tensor = torch.from_numpy(data).float()
+         else:
+             tensor = data.float()
+
+         # Add batch and channel dimensions if needed
+         if len(tensor.shape) == 2:
+             tensor = tensor.unsqueeze(0).unsqueeze(0)
+         elif len(tensor.shape) == 3:
+             tensor = tensor.unsqueeze(0)
+
+         # Apply normalization: z-score, then rescale to [0, 1]
+         tensor = (tensor - self.mean) / (self.std + 1e-8)
+         tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-8)
+
+         # Resize to 32x32 if needed
+         if tensor.shape[-1] != 32 or tensor.shape[-2] != 32:
+             tensor = F.interpolate(
+                 tensor,
+                 size=(32, 32),
+                 mode='bicubic',
+                 align_corners=False
+             )
+
+         return tensor.to(self.device)
+
+     def denormalize(self, tensor):
+         """
+         Denormalize an output tensor.
+
+         Args:
+             tensor: Output tensor from model
+
+         Returns:
+             Denormalized tensor in the original data range
+         """
+         # Scale from [0, 1] back to the standardized range
+         tensor = tensor * (self.max_val - self.min_val) + self.min_val
+
+         # Restore original scale
+         tensor = tensor * self.std + self.mean
+
+         return tensor
+
+     def predict(self, data, with_uncertainty=True, confidence_level=0.95):
+         """
+         Generate super-resolution output with uncertainty bounds.
+
+         Args:
+             data: Input data (numpy array, PIL Image, or tensor)
+             with_uncertainty: Whether to include uncertainty bounds
+             confidence_level: Confidence level for uncertainty bounds
+
+         Returns:
+             Tuple of (prediction, lower_bound, upper_bound) if with_uncertainty=True,
+             or just the prediction otherwise
+         """
+         # Preprocess input
+         input_tensor = self.preprocess(data)
+
+         with torch.no_grad():
+             # Run model inference
+             if with_uncertainty and hasattr(self.model, 'predict_with_uncertainty'):
+                 prediction, lower_bound, upper_bound = self.model.predict_with_uncertainty(
+                     input_tensor, confidence_level
+                 )
+
+                 # Denormalize outputs
+                 prediction = self.denormalize(prediction)
+                 lower_bound = self.denormalize(lower_bound) if lower_bound is not None else None
+                 upper_bound = self.denormalize(upper_bound) if upper_bound is not None else None
+
+                 # Convert to numpy
+                 prediction = prediction.cpu().numpy()
+                 lower_bound = lower_bound.cpu().numpy() if lower_bound is not None else None
+                 upper_bound = upper_bound.cpu().numpy() if upper_bound is not None else None
+
+                 return prediction, lower_bound, upper_bound
+             else:
+                 # Standard inference
+                 prediction = self.model(input_tensor)
+
+                 # Denormalize
+                 prediction = self.denormalize(prediction)
+
+                 # Convert to numpy
+                 prediction = prediction.cpu().numpy()
+
+                 return prediction
+
+     def load_npy(self, file_path):
+         """
+         Load bathymetry data from a numpy file.
+
+         Args:
+             file_path: Path to .npy file
+
+         Returns:
+             Numpy array containing bathymetry data
+         """
+         try:
+             return np.load(file_path)
+         except Exception as e:
+             raise ValueError(f"Error loading numpy file: {str(e)}")
+
+     @staticmethod
+     def get_uncertainty_width(lower_bound, upper_bound):
+         """
+         Calculate the mean uncertainty width (upper bound minus lower bound).
+
+         Args:
+             lower_bound: Lower uncertainty bound
+             upper_bound: Upper uncertainty bound
+
+         Returns:
+             Mean uncertainty width, or None if either bound is missing
+         """
+         if lower_bound is None or upper_bound is None:
+             return None
+
+         return np.mean(upper_bound - lower_bound)
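Putting the class to work end to end looks roughly like the sketch below; it assumes the checkpoint and config from this commit are in place and reuses the sample tile that `app.py`'s `create_sample_data()` writes:

```python
import numpy as np
from inference import BathymetrySuperResolution

sr = BathymetrySuperResolution(
    model_type="vqvae",
    checkpoint_path="checkpoints/calibrated.pth",
    config_path="config.json",
)

tile = np.load("samples/sample_bathymetry.npy")  # any 2-D depth grid; preprocess() resizes to 32x32
pred, lo, hi = sr.predict(tile, with_uncertainty=True, confidence_level=0.95)
print("prediction:", pred.shape)
print("mean CI width:", BathymetrySuperResolution.get_uncertainty_width(lo, hi))
```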
models/loader.py ADDED
@@ -0,0 +1,147 @@
+ import torch
+ import torch.nn as nn
+ from enum import Enum
+ from typing import Dict, Optional
+
+ class ModelType(Enum):
+     SRCNN = 'srcnn'
+     GAN = 'gan'
+     VQVAE = 'vqvae'
+
+ class ModelLoader:
+     """
+     Loader for the different super-resolution model architectures.
+     """
+     def __init__(self):
+         # Base model configurations
+         self.model_configs = {
+             ModelType.SRCNN: {
+                 "in_channels": 1,
+                 "hidden_channels": 64,
+                 "num_residual_blocks": 8,
+                 "num_upsamples": 1,
+                 "block_size": 4
+             },
+             ModelType.GAN: {
+                 "in_channels": 1,
+                 "hidden_channels": 64,
+                 "num_rrdb_blocks": 8,
+                 "growth_channels": 32,
+                 "num_upsamples": 1,
+                 "block_size": 4
+             },
+             ModelType.VQVAE: {
+                 "in_channels": 1,
+                 "hidden_dims": [32, 64, 128, 256],
+                 "num_embeddings": 512,
+                 "embedding_dim": 256,
+                 "block_size": 4
+             }
+         }
+
+         self.model_registry = {}
+         self.loaded_models = {}
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Register whichever model implementations are importable
+         try:
+             from .vqvae import VQVAE
+             self.model_registry[ModelType.VQVAE] = VQVAE
+         except ImportError:
+             print("Warning: VQVAE model implementation not found")
+
+         try:
+             from .cnn import CNN
+             self.model_registry[ModelType.SRCNN] = CNN
+         except ImportError:
+             print("Warning: CNN model implementation not found")
+
+         try:
+             from .gan import UncertainESRGAN
+             self.model_registry[ModelType.GAN] = UncertainESRGAN
+         except ImportError:
+             print("Warning: GAN model implementation not found")
+
+     def load_model(self, model_type: str, checkpoint_path: str, config_overrides: Optional[Dict] = None):
+         """
+         Load a model with its checkpoint and optional configuration overrides.
+
+         Args:
+             model_type: Type of model to load ('srcnn', 'gan', or 'vqvae')
+             checkpoint_path: Path to model checkpoint
+             config_overrides: Optional dictionary of configuration overrides
+
+         Returns:
+             Loaded model, or None if loading fails
+         """
+         try:
+             # Convert string to enum
+             model_type = ModelType(model_type.lower())
+
+             # Check if the model implementation is available
+             if model_type not in self.model_registry:
+                 raise ValueError(f"Model type {model_type.value} is not available")
+
+             # Get base config and apply overrides if provided
+             model_config = self.model_configs[model_type].copy()
+             if config_overrides:
+                 model_config.update(config_overrides)
+
+             # Initialize model with the (possibly overridden) config
+             model_class = self.model_registry[model_type]
+             model = model_class(**model_config)
+
+             # Move model to device
+             model = model.to(self.device)
+
+             # Load checkpoint
+             try:
+                 checkpoint = torch.load(checkpoint_path, map_location=self.device)
+                 model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+                 # Load uncertainty tracker state if available
+                 if hasattr(model, 'uncertainty_tracker'):
+                     model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', False)
+
+                     if model.uncertainty_tracker.calibrated:
+                         if 'block_scale_means' in checkpoint:
+                             model.uncertainty_tracker.block_scale_means = checkpoint['block_scale_means']
+                         if 'block_scale_stds' in checkpoint:
+                             model.uncertainty_tracker.block_scale_stds = checkpoint['block_scale_stds']
+
+                 print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
+             except Exception as e:
+                 print(f"Warning: Could not load checkpoint. Using untrained model. Error: {e}")
+
+             # Store model
+             model.eval()
+             self.loaded_models[model_type] = model
+
+             return model
+
+         except Exception as e:
+             print(f"Error loading model: {str(e)}")
+             return None
+
+     def get_model(self, model_type: str):
+         """Get a loaded model by type."""
+         try:
+             model_type = ModelType(model_type.lower())
+             return self.loaded_models.get(model_type)
+         except ValueError:
+             return None
+
+     def unload_model(self, model_type: str):
+         """Unload a specific model."""
+         try:
+             model_type = ModelType(model_type.lower())
+             if model_type in self.loaded_models:
+                 del self.loaded_models[model_type]
+                 torch.cuda.empty_cache()
+         except ValueError:
+             pass
+
+     def unload_all_models(self):
+         """Unload all loaded models."""
+         self.loaded_models.clear()
+         torch.cuda.empty_cache()
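`ModelLoader` can also be driven directly, for example to load the VQ-VAE with a config override; a sketch assuming the checkpoint from this commit:

```python
from models.loader import ModelLoader

loader = ModelLoader()
model = loader.load_model(
    "vqvae",
    "checkpoints/calibrated.pth",
    config_overrides={"block_size": 4},  # merged over the base VQVAE config
)
if model is not None:
    print(type(model).__name__, "on", next(model.parameters()).device)

loader.unload_all_models()  # drops references and empties the CUDA cache
```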
models/uncertainty.py ADDED
@@ -0,0 +1,188 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class BlockUncertaintyTracker(nn.Module):
+     """
+     Track and estimate uncertainty at block level for bathymetry super-resolution.
+     """
+     def __init__(self, block_size=4, alpha=0.1, decay=0.99, eps=1e-5):
+         """
+         Initialize the block-wise uncertainty tracker.
+
+         Args:
+             block_size: Size of spatial blocks for uncertainty estimation
+             alpha: Quantile parameter for uncertainty bounds
+             decay: EMA decay factor for tracking statistics
+             eps: Small value for numerical stability
+         """
+         super().__init__()
+         self.block_size = block_size
+         self.decay = decay
+         self.alpha = alpha
+         self.eps = eps
+
+         # Unfold layer for block extraction
+         self.unfold = nn.Unfold(kernel_size=block_size, stride=block_size)
+
+         # Register buffers; they are sized lazily on the first update
+         self.register_buffer('ema_errors', None)
+         self.register_buffer('ema_quantile', None)
+
+         self.num_blocks_h = None
+         self.num_blocks_w = None
+
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Calibration statistics
+         self.calibrated = False
+         self.block_means = []
+         self.block_stds = []
+         self.block_scale_means = None
+         self.block_scale_stds = None
+
+     def _initialize_buffers(self, h, w, device):
+         """Initialize EMA buffers based on the number of blocks in the image."""
+         self.num_blocks_h = h // self.block_size
+         self.num_blocks_w = w // self.block_size
+         num_blocks = self.num_blocks_h * self.num_blocks_w
+
+         # Initialize buffers on the correct device
+         self.ema_errors = torch.zeros(num_blocks, device=device)
+         self.ema_quantile = torch.zeros(num_blocks, device=device)
+
+     def update(self, current_errors):
+         """Update the EMA of errors and quantiles for each block."""
+         B, C, H, W = current_errors.shape
+         device = current_errors.device
+
+         # Initialize buffers if not done yet
+         if self.ema_errors is None:
+             self._initialize_buffers(H, W, device)
+
+         # Unfold into blocks (assumes single-channel input, C == 1)
+         blocks = self.unfold(current_errors)
+         block_errors = blocks.transpose(1, 2)
+         block_errors = block_errors.reshape(-1, self.num_blocks_h * self.num_blocks_w, self.block_size * self.block_size)
+
+         with torch.no_grad():
+             # Compute mean error per block
+             block_mean_errors = block_errors.mean(dim=-1)  # [B, num_blocks]
+
+             # Update EMA errors for each block
+             block_means = block_mean_errors.mean(dim=0)  # Average across batch
+             block_means = block_means.to(device)
+             self.ema_errors = self.ema_errors.to(device)
+             self.ema_errors.mul_(self.decay).add_(block_means * (1 - self.decay))
+
+             # Update quantiles for each block
+             block_quantiles = torch.quantile(block_errors, 1 - self.alpha, dim=-1)  # [B, num_blocks]
+             quantile_means = block_quantiles.mean(dim=0)  # Average across batch
+             quantile_means = quantile_means.to(device)
+             self.ema_quantile = self.ema_quantile.to(device)
+             self.ema_quantile.mul_(self.decay).add_(quantile_means * (1 - self.decay))
+
+     def get_uncertainty(self, errors):
+         """Calculate block-wise uncertainty scores."""
+         B, C, H, W = errors.shape
+         device = errors.device
+
+         # Initialize buffers if not done yet
+         if self.ema_errors is None:
+             self._initialize_buffers(H, W, device)
+
+         # Ensure buffers are on the correct device
+         self.ema_errors = self.ema_errors.to(device)
+         self.ema_quantile = self.ema_quantile.to(device)
+
+         # Unfold into blocks
+         blocks = self.unfold(errors)
+
+         # Calculate uncertainty for each block
+         uncertainties = []
+         for i in range(self.num_blocks_h * self.num_blocks_w):
+             block = blocks[:, :, i].view(B, C, self.block_size, self.block_size)
+             uncertainty = block / (self.ema_quantile[i] + self.eps)
+             uncertainties.append(uncertainty)
+
+         # Reconstruct the full image from blocks
+         uncertainty_blocks = torch.stack(uncertainties, dim=-1)
+         uncertainty_blocks = uncertainty_blocks.permute(0, 1, 4, 2, 3)
+
+         # Reshape to the original image size; the block index splits into
+         # (num_blocks_h, num_blocks_w) before the per-block spatial dims,
+         # matching the reconstruction in get_bounds below
+         uncertainty_map = uncertainty_blocks.reshape(
+             B, C,
+             self.num_blocks_h, self.num_blocks_w,
+             self.block_size, self.block_size
+         ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
+
+         return uncertainty_map
+
+     def get_bounds(self, x, confidence_level=0.95):
+         """
+         Get prediction bounds based on calibrated statistics.
+
+         Args:
+             x: Input tensor [B, C, H, W]
+             confidence_level: Confidence level for bounds
+
+         Returns:
+             tuple: (lower_bounds, upper_bounds)
+         """
+         if not self.calibrated:
+             print("Warning: Model not calibrated. Bounds may be inaccurate.")
+             # Return simple bounds based on mean error
+             return x * 0.9, x * 1.1
+
+         # Look up the z-score for the requested confidence level
+         z_scores = {
+             0.99: 2.576,
+             0.95: 1.96,
+             0.90: 1.645,
+             0.85: 1.440,
+             0.80: 1.282
+         }
+         z_score = z_scores.get(confidence_level, 1.96)
+
+         # Get block-wise uncertainty
+         blocks = self.unfold(x)  # [B, C*block_size*block_size, num_blocks]
+
+         # Calculate bounds for each block using calibrated statistics
+         B, C, H, W = x.shape
+         lower_bounds = []
+         upper_bounds = []
+
+         for i in range(self.num_blocks_h * self.num_blocks_w):
+             block = blocks[:, :, i].view(B, C, self.block_size, self.block_size)
+
+             # Use calibrated statistics to determine uncertainty
+             if hasattr(self, 'block_scale_stds') and self.block_scale_stds is not None:
+                 uncertainty = z_score * self.block_scale_stds[i]
+             else:
+                 # Fallback if calibration stats are not available
+                 uncertainty = 0.1 * block.mean()
+
+             lower_bound = torch.clamp(block - uncertainty, min=0.0)
+             upper_bound = torch.clamp(block + uncertainty, max=1.0)
+
+             lower_bounds.append(lower_bound)
+             upper_bounds.append(upper_bound)
+
+         # Reconstruct the full image from blocks
+         lower_bounds = torch.stack(lower_bounds, dim=-1)
+         upper_bounds = torch.stack(upper_bounds, dim=-1)
+
+         # Reshape to the original image size
+         lower_bounds = lower_bounds.permute(0, 1, 4, 2, 3).reshape(
+             B, C,
+             self.num_blocks_h, self.num_blocks_w,
+             self.block_size, self.block_size
+         ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
+
+         upper_bounds = upper_bounds.permute(0, 1, 4, 2, 3).reshape(
+             B, C,
+             self.num_blocks_h, self.num_blocks_w,
+             self.block_size, self.block_size
+         ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
+
+         return lower_bounds, upper_bounds
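The tracker's EMA statistics come entirely from `update()` calls during training; `get_bounds()` then uses either the calibrated scales loaded from the checkpoint or a ±10% fallback. A minimal synthetic exercise (hypothetical; single-channel input, as `update()` assumes):

```python
import torch
from models.uncertainty import BlockUncertaintyTracker

tracker = BlockUncertaintyTracker(block_size=4, alpha=0.1)

for _ in range(10):
    fake_errors = torch.rand(8, 1, 64, 64) * 0.05  # stand-in for |prediction - target|
    tracker.update(fake_errors)

x = torch.rand(1, 1, 64, 64)
lo, hi = tracker.get_bounds(x, confidence_level=0.95)  # uncalibrated -> +/-10% fallback
print("mean interval width:", (hi - lo).mean().item())
```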
models/vqvae.py ADDED
@@ -0,0 +1,262 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .uncertainty import BlockUncertaintyTracker
+
+ class ResidualAttentionBlock(nn.Module):
+     """Residual attention block for capturing spatial dependencies."""
+     def __init__(self, in_channels):
+         super().__init__()
+
+         # Trunk branch
+         self.trunk = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
+             nn.BatchNorm2d(in_channels),
+             nn.SiLU(),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
+             nn.BatchNorm2d(in_channels)
+         )
+
+         # Mask branch for attention
+         self.mask = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, in_channels, kernel_size=1),
+             nn.SiLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         # Trunk branch
+         trunk_output = self.trunk(x)
+
+         # Mask branch for attention weights
+         attention = self.mask(x)
+
+         # Apply attention and residual connection
+         out = x + attention * trunk_output
+         return F.silu(out)
+
+ class VectorQuantizer(nn.Module):
+     """Vector quantizer for discrete latent representation."""
+     def __init__(self, n_embeddings=512, embedding_dim=256, beta=0.25):
+         super().__init__()
+         self.n_embeddings = n_embeddings
+         self.embedding_dim = embedding_dim
+         self.beta = beta
+
+         # Initialize embeddings
+         self.embeddings = nn.Parameter(torch.randn(n_embeddings, embedding_dim))
+         nn.init.uniform_(self.embeddings, -1.0 / n_embeddings, 1.0 / n_embeddings)
+
+         # Usage tracking
+         self.register_buffer('usage', torch.zeros(n_embeddings))
+
+     def forward(self, z):
+         # Move channels last so each spatial position becomes one
+         # embedding_dim vector, then flatten for quantization
+         z_perm = z.permute(0, 2, 3, 1).contiguous()
+         z_flattened = z_perm.reshape(-1, self.embedding_dim)
+
+         # Calculate distances to embedding vectors
+         distances = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
+                     torch.sum(self.embeddings**2, dim=1) - \
+                     2 * torch.matmul(z_flattened, self.embeddings.t())
+
+         # Find the nearest embedding for each input vector
+         encoding_indices = torch.argmin(distances, dim=1)
+
+         # Update usage statistics
+         if self.training:
+             with torch.no_grad():
+                 usage = torch.zeros_like(self.usage)
+                 usage.scatter_add_(0, encoding_indices, torch.ones_like(encoding_indices, dtype=torch.float))
+                 self.usage.mul_(0.99).add_(usage, alpha=0.01)
+
+         # Get quantized vectors and restore the [B, C, H, W] layout
+         z_q = self.embeddings[encoding_indices].reshape(z_perm.shape)
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         # Calculate loss terms
+         commitment_loss = F.mse_loss(z_q.detach(), z)
+         codebook_loss = F.mse_loss(z_q, z.detach())
+
+         # Combine losses
+         loss = codebook_loss + self.beta * commitment_loss
+
+         # Straight-through estimator
+         z_q = z + (z_q - z).detach()
+
+         if self.training:
+             return z_q, loss
+         else:
+             return z_q
+
+ class Encoder(nn.Module):
+     """Encoder for the VQ-VAE model."""
+     def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256], embedding_dim=256):
+         super().__init__()
+
+         # Initial conv layer
+         layers = [
+             nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(hidden_dims[0]),
+             nn.SiLU()
+         ]
+
+         # Hidden layers with downsampling
+         for i in range(len(hidden_dims) - 1):
+             layers.extend([
+                 nn.Conv2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=4, stride=2, padding=1),
+                 nn.BatchNorm2d(hidden_dims[i + 1]),
+                 nn.SiLU()
+             ])
+
+         # Residual attention blocks
+         for _ in range(2):
+             layers.append(ResidualAttentionBlock(hidden_dims[-1]))
+
+         # Final projection to the embedding dimension
+         layers.extend([
+             nn.Conv2d(hidden_dims[-1], embedding_dim, kernel_size=1),
+             nn.BatchNorm2d(embedding_dim)
+         ])
+
+         self.encoder = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.encoder(x)
+
+ class Decoder(nn.Module):
+     """Decoder for the VQ-VAE model."""
+     def __init__(self, embedding_dim=256, hidden_dims=[256, 128, 64, 32], out_channels=1):
+         super().__init__()
+
+         # Reverse hidden dims for the decoder
+         hidden_dims = hidden_dims[::-1]
+
+         # Initial processing
+         layers = [
+             nn.Conv2d(embedding_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(hidden_dims[0]),
+             nn.SiLU()
+         ]
+
+         # Residual attention blocks
+         for _ in range(2):
+             layers.append(ResidualAttentionBlock(hidden_dims[0]))
+
+         # Upsampling blocks
+         for i in range(len(hidden_dims) - 1):
+             layers.extend([
+                 nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
+                                    kernel_size=4, stride=2, padding=1),
+                 nn.BatchNorm2d(hidden_dims[i + 1]),
+                 nn.SiLU()
+             ])
+
+         # Final output layer
+         layers.append(
+             nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1)
+         )
+         layers.append(nn.Sigmoid())
+
+         self.decoder = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.decoder(x)
+
+ class VQVAE(nn.Module):
+     """
+     Vector Quantized Variational Autoencoder with uncertainty awareness
+     for bathymetry super-resolution.
+     """
+     def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256],
+                  num_embeddings=512, embedding_dim=256, block_size=4, alpha=0.1):
+         super().__init__()
+
+         # Initialize block-wise uncertainty tracking
+         self.uncertainty_tracker = BlockUncertaintyTracker(
+             block_size=block_size,
+             alpha=alpha,
+             decay=0.99,
+             eps=1e-5
+         )
+
+         # Main model components
+         self.encoder = Encoder(
+             in_channels=in_channels,
+             hidden_dims=hidden_dims,
+             embedding_dim=embedding_dim
+         )
+
+         self.vq = VectorQuantizer(
+             n_embeddings=num_embeddings,
+             embedding_dim=embedding_dim,
+             beta=0.25
+         )
+
+         self.decoder = Decoder(
+             embedding_dim=embedding_dim,
+             hidden_dims=hidden_dims,
+             out_channels=in_channels
+         )
+
+     def forward(self, x):
+         """Forward pass through the model."""
+         # Encode
+         z = self.encoder(x)
+
+         # Vector quantization
+         if self.training:
+             z_q, vq_loss = self.vq(z)
+
+             # Decode
+             reconstruction = self.decoder(z_q)
+
+             return reconstruction, vq_loss
+         else:
+             z_q = self.vq(z)
+
+             # Decode
+             reconstruction = self.decoder(z_q)
+
+             return reconstruction
+
+     def train_forward(self, x, y):
+         """Training forward pass with uncertainty tracking."""
+         # Get reconstruction and VQ loss
+         reconstruction, vq_loss = self.forward(x)
+
+         # Calculate reconstruction error
+         error = torch.abs(reconstruction - y)
+
+         # Update uncertainty tracker
+         self.uncertainty_tracker.update(error)
+
+         # Get uncertainty map for loss weighting
+         uncertainty_map = self.uncertainty_tracker.get_uncertainty(error)
+
+         return reconstruction, vq_loss, uncertainty_map
+
+     def predict_with_uncertainty(self, x, confidence_level=0.95):
+         """
+         Forward pass with calibrated uncertainty bounds.
+
+         Args:
+             x: Input tensor
+             confidence_level: Confidence level for bounds (default: 0.95)
+
+         Returns:
+             tuple: (reconstruction, lower_bounds, upper_bounds)
+         """
+         self.eval()
+         with torch.no_grad():
+             # Get reconstruction
+             reconstruction = self.forward(x)
+
+             # Get calibrated uncertainty bounds
+             lower_bounds, upper_bounds = self.uncertainty_tracker.get_bounds(
+                 reconstruction, confidence_level
+             )
+
+         return reconstruction, lower_bounds, upper_bounds
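A shape-level smoke test of the full model (a sketch with random weights, so the values are meaningless; it only confirms the tensor plumbing, and note that the output spatial size mirrors the symmetric encoder/decoder stack):

```python
import torch
from models.vqvae import VQVAE

model = VQVAE(in_channels=1, hidden_dims=[32, 64, 128, 256],
              num_embeddings=512, embedding_dim=256, block_size=4)

x = torch.rand(2, 1, 32, 32)  # normalized low-resolution tiles
recon, lo, hi = model.predict_with_uncertainty(x, confidence_level=0.95)
print(recon.shape, lo.shape, hi.shape)

model.train()
recon, vq_loss = model(x)  # training path also returns the VQ loss
print("vq loss:", vq_loss.item())
```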
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.0.0
+ numpy>=1.20.0
+ matplotlib>=3.5.0
+ gradio>=3.32.0
+ Pillow>=9.0.0
+ scipy>=1.8.0
+ tqdm>=4.62.0
+ huggingface_hub>=0.14.0
+ transformers>=4.30.0
+ pandas>=1.3.0
+ scikit-learn>=1.0.0