Commit 95f0e22 · Jose Marie Antonio Minoza (parent: 260f5fa)
Initial commit
Files changed:
- app.py (+228)
- checkpoints/calibrated.pth (+3)
- config.json (+16)
- inference.py (+207)
- models/loader.py (+147)
- models/uncertainty.py (+188)
- models/vqvae.py (+262)
- requirements.txt (+11)
app.py
ADDED
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import gradio as gr
from PIL import Image
from pathlib import Path

# Import the inference module
from inference import BathymetrySuperResolution

# Define checkpoint and config paths
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "checkpoints")
MODEL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "calibrated.pth")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "config.json")

# Initialize model
try:
    model = BathymetrySuperResolution(
        model_type="vqvae",
        checkpoint_path=MODEL_CHECKPOINT,
        config_path=CONFIG_PATH
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {str(e)}")
    model = None
    model_loaded = False

def process_upload(file, confidence_level, block_size, model_type):
    """Process an uploaded bathymetry file"""
    if file is None:
        return None, "Please upload a file."

    try:
        # Check if the model is loaded
        if not model_loaded:
            return None, "Model not loaded. Please check server logs."

        # Load the data
        if file.name.endswith('.npy'):
            data = np.load(file.name)
        else:
            # Try to load as an image
            img = Image.open(file.name).convert('L')
            data = np.array(img)

        # Update model configuration if needed
        if model.config['model_type'] != model_type or model.config['model_config']['block_size'] != block_size:
            # In a real app, you would reload the model or adjust the configuration
            pass

        # Run the prediction
        prediction, lower_bound, upper_bound = model.predict(
            data,
            with_uncertainty=True,
            confidence_level=confidence_level / 100.0  # Convert percentage to fraction
        )

        # Calculate uncertainty width
        uncertainty_width = model.get_uncertainty_width(lower_bound, upper_bound)

        # Create visualization
        fig = plt.figure(figsize=(15, 10))

        # Original input (resized to 32x32 if needed)
        ax1 = fig.add_subplot(231)
        if data.shape != (32, 32):
            from scipy.ndimage import zoom
            zoom_factor = 32 / max(data.shape)
            input_data = zoom(data, zoom_factor)
        else:
            input_data = data
        im1 = ax1.imshow(input_data, cmap=cm.viridis)
        ax1.set_title("Input (32x32)")
        plt.colorbar(im1, ax=ax1)

        # Super-resolution output
        ax2 = fig.add_subplot(232)
        im2 = ax2.imshow(prediction[0, 0], cmap=cm.viridis)
        ax2.set_title("Super-Resolution (64x64)")
        plt.colorbar(im2, ax=ax2)

        # Lower bound
        ax3 = fig.add_subplot(233)
        im3 = ax3.imshow(lower_bound[0, 0], cmap=cm.viridis)
        ax3.set_title(f"Lower Bound ({confidence_level}% CI)")
        plt.colorbar(im3, ax=ax3)

        # Upper bound
        ax4 = fig.add_subplot(234)
        im4 = ax4.imshow(upper_bound[0, 0], cmap=cm.viridis)
        ax4.set_title(f"Upper Bound ({confidence_level}% CI)")
        plt.colorbar(im4, ax=ax4)

        # Uncertainty width visualization
        ax5 = fig.add_subplot(235)
        uncertainty_map = upper_bound[0, 0] - lower_bound[0, 0]
        im5 = ax5.imshow(uncertainty_map, cmap='hot')
        ax5.set_title("Uncertainty Width")
        plt.colorbar(im5, ax=ax5)

        # 3D surface plot
        ax6 = fig.add_subplot(236, projection='3d')
        x = np.arange(0, prediction.shape[2])
        y = np.arange(0, prediction.shape[3])
        X, Y = np.meshgrid(x, y)
        surf = ax6.plot_surface(X, Y, prediction[0, 0], cmap=cm.viridis,
                                linewidth=0, antialiased=True)
        ax6.set_title("3D Bathymetry")

        plt.tight_layout()

        # Return the figure and a summary text
        summary = f"""
        **Super-Resolution Results:**
        - **Model Type**: {model_type.upper()}
        - **Block Size**: {block_size}×{block_size}
        - **Confidence Level**: {confidence_level}%
        - **Average Uncertainty Width**: {uncertainty_width:.4f}
        - **Input Shape**: {data.shape}
        - **Output Shape**: {prediction.shape[2:]}
        """

        return fig, summary

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error processing file: {str(e)}"

def create_sample_data():
    """Create a sample bathymetry data file for demonstration"""
    # Create a synthetic bathymetry profile with features
    x = np.linspace(0, 1, 32)
    y = np.linspace(0, 1, 32)
    xx, yy = np.meshgrid(x, y)

    # Create a surface with a ridge and a valley
    z = -4000 + 500 * np.sin(10 * xx) * np.cos(8 * yy) + 300 * np.exp(-((xx - 0.3)**2 + (yy - 0.7)**2) / 0.1)

    # Save to a temporary file
    sample_dir = Path("samples")
    sample_dir.mkdir(exist_ok=True)
    sample_path = sample_dir / "sample_bathymetry.npy"
    np.save(sample_path, z)

    return str(sample_path)

# Create the Gradio interface
with gr.Blocks(title="Bathymetry Super-Resolution") as demo:
    gr.Markdown("""
    # Bathymetry Super-Resolution with Uncertainty Quantification

    This application demonstrates super-resolution of ocean floor (bathymetry) data with uncertainty estimates.
    Upload a bathymetry file (NPY or image) to see the enhanced-resolution output with confidence intervals.

    The model uses a **Vector Quantized Variational Autoencoder (VQ-VAE)** with **block-based uncertainty quantification**.
    """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload Bathymetry File (.npy or image)")

            with gr.Row():
                confidence_level = gr.Slider(
                    minimum=80, maximum=99, value=95, step=1,
                    label="Confidence Level (%)"
                )

                block_size = gr.Dropdown(
                    choices=[1, 2, 4, 8, 64], value=4,
                    label="Block Size"
                )

                model_type = gr.Dropdown(
                    choices=["vqvae", "srcnn", "gan"], value="vqvae",
                    label="Model Type"
                )

            with gr.Row():
                process_btn = gr.Button("Generate Super-Resolution")
                sample_btn = gr.Button("Load Sample Data")

        with gr.Column():
            output_plots = gr.Plot(label="Super-Resolution Results")
            output_text = gr.Markdown(label="Summary")

    # Set up button actions
    process_btn.click(
        fn=process_upload,
        inputs=[input_file, confidence_level, block_size, model_type],
        outputs=[output_plots, output_text]
    )

    # Sample data generation
    sample_btn.click(
        fn=lambda: gr.update(value=create_sample_data()),
        inputs=None,
        outputs=input_file
    )

    gr.Markdown("""
    ## About This Model

    This model enhances the resolution of bathymetric data from 32×32 to 64×64 while providing uncertainty estimates.
    It was trained on bathymetry data from multiple ocean regions including the Eastern Pacific Basin, Western Pacific Region, and Indian Ocean Basin.

    The uncertainty estimates help identify areas where the model is less confident in its predictions, which is crucial for:
    - Risk assessment in coastal hazard modeling
    - Climate change impact analysis
    - Tsunami propagation simulation

    ## Model Performance

    | Model | SSIM | PSNR | MSE | MAE | UWidth | CalErr |
    |-------|------|------|-----|-----|--------|--------|
    | UA-VQ-VAE | 0.9433 | 26.8779 | 0.0021 | 0.0317 | 0.1046 | 0.0664 |
    """)

# Launch the demo
if __name__ == "__main__":
    if model_loaded:
        print("Model loaded successfully. Starting Gradio interface.")
    else:
        print("Warning: Model not loaded. Demo will display errors when processing files.")

    demo.launch()
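With the calibrated checkpoint in place (the default paths above can be overridden via the CHECKPOINT_DIR and CONFIG_PATH environment variables), the demo is started with `python app.py`, which serves the Gradio interface through `demo.launch()`.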
checkpoints/calibrated.pth
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:5c2ba632b12fe06e6c684c964ea07424b074e36ed533c3b03fa3bb4e8bf1c67e
size 235336891
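This is a Git LFS pointer file rather than the weights themselves: the actual calibrated.pth (roughly 224 MB, per the size field) lives in LFS storage and is materialized by `git lfs pull` after cloning.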
config.json
ADDED
{
    "model_type": "vqvae",
    "model_config": {
        "in_channels": 1,
        "hidden_dims": [32, 64, 128, 256],
        "num_embeddings": 512,
        "embedding_dim": 256,
        "block_size": 4
    },
    "normalization": {
        "mean": -3911.3894,
        "std": 1172.8374,
        "min": 0.0,
        "max": 1.0
    }
}
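For reference, a minimal sketch of how inference.py consumes these values in preprocess(): a z-score with the dataset mean/std, followed by a per-sample min-max rescale to [0, 1]. The depth values below are made up purely for illustration.

import json
import numpy as np

# Sketch of the normalization applied in inference.py's preprocess();
# the depth array is hypothetical, not from the repository.
with open("config.json") as f:
    norm = json.load(f)["normalization"]

depths = np.array([-4200.0, -3900.0, -3600.0])           # illustrative depths (m)
z = (depths - norm["mean"]) / (norm["std"] + 1e-8)       # z-score with dataset stats
scaled = (z - z.min()) / (z.max() - z.min() + 1e-8)      # per-sample min-max to [0, 1]
print(scaled)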
inference.py
ADDED
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
import json

# Import your model components
from models.loader import ModelLoader
from models.uncertainty import BlockUncertaintyTracker

class BathymetrySuperResolution:
    """
    Bathymetry super-resolution model with uncertainty estimation
    """
    def __init__(self, model_type="vqvae", checkpoint_path=None, config_path=None):
        """
        Initialize the super-resolution model with uncertainty awareness

        Args:
            model_type: Type of model ('srcnn', 'gan', or 'vqvae')
            checkpoint_path: Path to model checkpoint
            config_path: Path to configuration file
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load config if provided
        if config_path is not None and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default configuration
            self.config = {
                "model_type": model_type,
                "model_config": {
                    "in_channels": 1,
                    "hidden_dims": [32, 64, 128, 256],
                    "num_embeddings": 512,
                    "embedding_dim": 256,
                    "block_size": 4
                },
                "normalization": {
                    "mean": -3911.3894,
                    "std": 1172.8374,
                    "min": 0.0,
                    "max": 1.0
                }
            }

        # Initialize model loader
        self.model_loader = ModelLoader()

        # Load model
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            self.model = self.model_loader.load_model(
                self.config['model_type'],
                checkpoint_path,
                config_overrides=self.config.get('model_config', {})
            )
        else:
            raise ValueError("Checkpoint path not provided or invalid")

        # Ensure model is in eval mode
        self.model.eval()

        # Load normalization parameters
        self.mean = self.config['normalization']['mean']
        self.std = self.config['normalization']['std']
        self.min_val = self.config['normalization']['min']
        self.max_val = self.config['normalization']['max']

    def preprocess(self, data):
        """
        Preprocess input data for the model

        Args:
            data: Input array/image (can be numpy array, PIL Image, or tensor)

        Returns:
            Preprocessed tensor
        """
        # Convert PIL Image to numpy if needed
        if isinstance(data, Image.Image):
            data = np.array(data)

        # Convert numpy to tensor if needed
        if isinstance(data, np.ndarray):
            tensor = torch.from_numpy(data).float()
        else:
            tensor = data.float()

        # Add batch and channel dimensions if needed
        if len(tensor.shape) == 2:
            tensor = tensor.unsqueeze(0).unsqueeze(0)
        elif len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)

        # Apply normalization
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-8)

        # Resize if needed (to 32x32)
        if tensor.shape[-1] != 32 or tensor.shape[-2] != 32:
            tensor = F.interpolate(
                tensor,
                size=(32, 32),
                mode='bicubic',
                align_corners=False
            )

        return tensor.to(self.device)

    def denormalize(self, tensor):
        """
        Denormalize output tensor

        Args:
            tensor: Output tensor from model

        Returns:
            Denormalized tensor in original data range
        """
        # Scale from [0,1] back to original range
        tensor = tensor * (self.max_val - self.min_val) + self.min_val

        # Restore original scale
        tensor = tensor * self.std + self.mean

        return tensor

    def predict(self, data, with_uncertainty=True, confidence_level=0.95):
        """
        Generate super-resolution output with uncertainty bounds

        Args:
            data: Input data (can be numpy array, PIL Image, or tensor)
            with_uncertainty: Whether to include uncertainty bounds
            confidence_level: Confidence level for uncertainty bounds

        Returns:
            Tuple of (prediction, lower_bound, upper_bound) if with_uncertainty=True
            or just prediction otherwise
        """
        # Preprocess input
        input_tensor = self.preprocess(data)

        with torch.no_grad():
            # Run model inference
            if with_uncertainty and hasattr(self.model, 'predict_with_uncertainty'):
                prediction, lower_bound, upper_bound = self.model.predict_with_uncertainty(
                    input_tensor, confidence_level
                )

                # Denormalize outputs
                prediction = self.denormalize(prediction)
                lower_bound = self.denormalize(lower_bound) if lower_bound is not None else None
                upper_bound = self.denormalize(upper_bound) if upper_bound is not None else None

                # Convert to numpy
                prediction = prediction.cpu().numpy()
                lower_bound = lower_bound.cpu().numpy() if lower_bound is not None else None
                upper_bound = upper_bound.cpu().numpy() if upper_bound is not None else None

                return prediction, lower_bound, upper_bound
            else:
                # Standard inference
                prediction = self.model(input_tensor)

                # Denormalize
                prediction = self.denormalize(prediction)

                # Convert to numpy
                prediction = prediction.cpu().numpy()

                return prediction

    def load_npy(self, file_path):
        """
        Load bathymetry data from numpy file

        Args:
            file_path: Path to .npy file

        Returns:
            Numpy array containing bathymetry data
        """
        try:
            return np.load(file_path)
        except Exception as e:
            raise ValueError(f"Error loading numpy file: {str(e)}")

    @staticmethod
    def get_uncertainty_width(lower_bound, upper_bound):
        """
        Calculate uncertainty width (difference between upper and lower bounds)

        Args:
            lower_bound: Lower uncertainty bound
            upper_bound: Upper uncertainty bound

        Returns:
            Uncertainty width
        """
        if lower_bound is None or upper_bound is None:
            return None

        return np.mean(upper_bound - lower_bound)
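A minimal usage sketch of the class above, assuming the checkpoint and config committed here sit at their default paths; the input patch is a random stand-in for a real bathymetry tile.

import numpy as np
from inference import BathymetrySuperResolution

# End-to-end call; paths follow the defaults used by app.py.
sr = BathymetrySuperResolution(
    model_type="vqvae",
    checkpoint_path="checkpoints/calibrated.pth",
    config_path="config.json",
)

patch = np.random.rand(32, 32).astype(np.float32)  # stand-in for a real 32x32 tile
pred, lo, hi = sr.predict(patch, with_uncertainty=True, confidence_level=0.95)
print(pred.shape, sr.get_uncertainty_width(lo, hi))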
models/loader.py
ADDED
import torch
import torch.nn as nn
from enum import Enum
from typing import Dict, Optional

class ModelType(Enum):
    SRCNN = 'srcnn'
    GAN = 'gan'
    VQVAE = 'vqvae'

class ModelLoader:
    """
    Loader for different super-resolution model architectures
    """
    def __init__(self):
        # Base model configurations
        self.model_configs = {
            ModelType.SRCNN: {
                "in_channels": 1,
                "hidden_channels": 64,
                "num_residual_blocks": 8,
                "num_upsamples": 1,
                "block_size": 4
            },
            ModelType.GAN: {
                "in_channels": 1,
                "hidden_channels": 64,
                "num_rrdb_blocks": 8,
                "growth_channels": 32,
                "num_upsamples": 1,
                "block_size": 4
            },
            ModelType.VQVAE: {
                "in_channels": 1,
                "hidden_dims": [32, 64, 128, 256],
                "num_embeddings": 512,
                "embedding_dim": 256,
                "block_size": 4
            }
        }

        self.model_registry = {}
        self.loaded_models = {}
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Try to import model classes
        try:
            from .vqvae import VQVAE
            self.model_registry[ModelType.VQVAE] = VQVAE
        except ImportError:
            print("Warning: VQVAE model implementation not found")

        try:
            from .cnn import CNN
            self.model_registry[ModelType.SRCNN] = CNN
        except ImportError:
            print("Warning: CNN model implementation not found")

        try:
            from .gan import UncertainESRGAN
            self.model_registry[ModelType.GAN] = UncertainESRGAN
        except ImportError:
            print("Warning: GAN model implementation not found")

    def load_model(self, model_type: str, checkpoint_path: str, config_overrides: Optional[Dict] = None):
        """
        Load a model with its checkpoint and optional configuration overrides

        Args:
            model_type: Type of model to load ('srcnn', 'gan', or 'vqvae')
            checkpoint_path: Path to model checkpoint
            config_overrides: Optional dictionary of configuration overrides

        Returns:
            Loaded model or None if loading fails
        """
        try:
            # Convert string to enum
            model_type = ModelType(model_type.lower())

            # Check if model implementation is available
            if model_type not in self.model_registry:
                raise ValueError(f"Model type {model_type.value} is not available")

            # Get base config and apply overrides if provided
            model_config = self.model_configs[model_type].copy()
            if config_overrides:
                model_config.update(config_overrides)

            # Initialize model with potentially modified config
            model_class = self.model_registry[model_type]
            model = model_class(**model_config)

            # Move model to device
            model = model.to(self.device)

            # Load checkpoint
            try:
                checkpoint = torch.load(checkpoint_path, map_location=self.device)
                model.load_state_dict(checkpoint['state_dict'], strict=False)

                # Load uncertainty tracker state if available
                if hasattr(model, 'uncertainty_tracker'):
                    model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', False)

                    if model.uncertainty_tracker.calibrated:
                        if 'block_scale_means' in checkpoint:
                            model.uncertainty_tracker.block_scale_means = checkpoint['block_scale_means']
                        if 'block_scale_stds' in checkpoint:
                            model.uncertainty_tracker.block_scale_stds = checkpoint['block_scale_stds']

                print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
            except Exception as e:
                print(f"Warning: Could not load checkpoint. Using untrained model. Error: {e}")

            # Store model
            model.eval()
            self.loaded_models[model_type] = model

            return model

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            return None

    def get_model(self, model_type: str):
        """Get a loaded model by type"""
        try:
            model_type = ModelType(model_type.lower())
            return self.loaded_models.get(model_type)
        except ValueError:  # unknown model type string
            return None

    def unload_model(self, model_type: str):
        """Unload a specific model"""
        try:
            model_type = ModelType(model_type.lower())
            if model_type in self.loaded_models:
                del self.loaded_models[model_type]
                torch.cuda.empty_cache()
        except ValueError:  # unknown model type string; nothing to unload
            pass

    def unload_all_models(self):
        """Unload all loaded models"""
        self.loaded_models.clear()
        torch.cuda.empty_cache()
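A short sketch of driving the loader directly. The override dictionary mirrors config.json and is only illustrative, since the VQ-VAE defaults above already match it; load_model returns None rather than raising on failure, so the result should be checked.

from models.loader import ModelLoader

loader = ModelLoader()
model = loader.load_model(
    "vqvae",
    "checkpoints/calibrated.pth",
    config_overrides={"embedding_dim": 256, "block_size": 4},  # mirrors config.json
)
if model is None:
    print("Model failed to load; see warnings above.")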
models/uncertainty.py
ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlockUncertaintyTracker(nn.Module):
    """
    Track and estimate uncertainty at block level for bathymetry super-resolution
    """
    def __init__(self, block_size=4, alpha=0.1, decay=0.99, eps=1e-5):
        """
        Initialize block-wise uncertainty tracker

        Args:
            block_size: Size of spatial blocks for uncertainty estimation
            alpha: Quantile parameter for uncertainty bounds
            decay: EMA decay factor for tracking statistics
            eps: Small value for numerical stability
        """
        super().__init__()
        self.block_size = block_size
        self.decay = decay
        self.alpha = alpha
        self.eps = eps

        # Initialize unfold layer for block extraction
        self.unfold = nn.Unfold(kernel_size=block_size, stride=block_size)

        # Register buffers with initial values
        self.register_buffer('ema_errors', None)
        self.register_buffer('ema_quantile', None)

        self.num_blocks_h = None
        self.num_blocks_w = None

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Calibration statistics
        self.calibrated = False
        self.block_means = []
        self.block_stds = []
        self.block_scale_means = None
        self.block_scale_stds = None

    def _initialize_buffers(self, h, w, device):
        """Initialize EMA buffers based on number of blocks in image"""
        self.num_blocks_h = h // self.block_size
        self.num_blocks_w = w // self.block_size
        num_blocks = self.num_blocks_h * self.num_blocks_w

        # Initialize buffers on the correct device
        self.ema_errors = torch.zeros(num_blocks, device=device)
        self.ema_quantile = torch.zeros(num_blocks, device=device)

    def update(self, current_errors):
        """Update EMA of errors and quantiles for each block"""
        B, C, H, W = current_errors.shape
        device = current_errors.device

        # Initialize buffers if not done yet
        if self.ema_errors is None:
            self._initialize_buffers(H, W, device)

        # Unfold into blocks
        blocks = self.unfold(current_errors)
        block_errors = blocks.transpose(1, 2)
        block_errors = block_errors.reshape(-1, self.num_blocks_h * self.num_blocks_w, self.block_size * self.block_size)

        with torch.no_grad():
            # Compute mean error per block
            block_mean_errors = block_errors.mean(dim=-1)  # [B, num_blocks]

            # Update EMA errors for each block
            block_means = block_mean_errors.mean(dim=0)  # Average across batch
            block_means = block_means.to(device)  # Ensure on correct device
            self.ema_errors = self.ema_errors.to(device)  # Ensure on correct device
            self.ema_errors.mul_(self.decay).add_(block_means * (1 - self.decay))

            # Update quantiles for each block
            block_quantiles = torch.quantile(block_errors, 1 - self.alpha, dim=-1)  # [B, num_blocks]
            quantile_means = block_quantiles.mean(dim=0)  # Average across batch
            quantile_means = quantile_means.to(device)  # Ensure on correct device
            self.ema_quantile = self.ema_quantile.to(device)  # Ensure on correct device
            self.ema_quantile.mul_(self.decay).add_(quantile_means * (1 - self.decay))

    def get_uncertainty(self, errors):
        """Calculate block-wise uncertainty scores"""
        B, C, H, W = errors.shape
        device = errors.device

        # Initialize buffers if not done yet
        if self.ema_errors is None:
            self._initialize_buffers(H, W, device)

        # Ensure buffers are on correct device
        self.ema_errors = self.ema_errors.to(device)
        self.ema_quantile = self.ema_quantile.to(device)

        # Unfold into blocks
        blocks = self.unfold(errors)

        # Calculate uncertainty for each block
        uncertainties = []
        for i in range(self.num_blocks_h * self.num_blocks_w):
            block = blocks[:, :, i].view(B, C, self.block_size, self.block_size)
            uncertainty = block / (self.ema_quantile[i] + self.eps)
            uncertainties.append(uncertainty)

        # Reconstruct full image from blocks
        uncertainty_blocks = torch.stack(uncertainties, dim=-1)
        uncertainty_blocks = uncertainty_blocks.permute(0, 1, 4, 2, 3)

        # Reshape to original image size. Blocks are stacked in block-major
        # (row-by-row) order, so view as [B, C, nh, nw, bs, bs] first, then
        # interleave block rows/columns back into the pixel grid (the same
        # pattern used in get_bounds below).
        uncertainty_map = uncertainty_blocks.reshape(
            B, C,
            self.num_blocks_h, self.num_blocks_w,
            self.block_size, self.block_size
        ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)

        return uncertainty_map

    def get_bounds(self, x, confidence_level=0.95):
        """
        Get prediction bounds based on calibrated statistics

        Args:
            x: Input tensor [B, C, H, W]
            confidence_level: Confidence level for bounds

        Returns:
            tuple: (lower_bounds, upper_bounds)
        """
        if not self.calibrated:
            print("Warning: Model not calibrated. Bounds may be inaccurate.")
            # Return simple bounds based on mean error
            return x * 0.9, x * 1.1

        # Calculate z-score based on confidence level
        z_scores = {
            0.99: 2.576,
            0.95: 1.96,
            0.90: 1.645,
            0.85: 1.440,
            0.80: 1.282
        }
        z_score = z_scores.get(confidence_level, 1.96)

        # Get block-wise uncertainty
        blocks = self.unfold(x)  # [B, C*block_size*block_size, num_blocks]

        # Calculate bounds for each block using calibrated statistics
        B, C, H, W = x.shape

        # Lazily initialize block geometry if bounds are requested before update()
        if self.num_blocks_h is None:
            self._initialize_buffers(H, W, x.device)

        lower_bounds = []
        upper_bounds = []

        for i in range(self.num_blocks_h * self.num_blocks_w):
            block = blocks[:, :, i].view(B, C, self.block_size, self.block_size)

            # Use calibrated statistics to determine uncertainty
            if hasattr(self, 'block_scale_stds') and self.block_scale_stds is not None:
                uncertainty = z_score * self.block_scale_stds[i]
            else:
                # Fallback if calibration stats not available
                uncertainty = 0.1 * block.mean()

            lower_bound = torch.clamp(block - uncertainty, min=0.0)
            upper_bound = torch.clamp(block + uncertainty, max=1.0)

            lower_bounds.append(lower_bound)
            upper_bounds.append(upper_bound)

        # Reconstruct full image from blocks
        lower_bounds = torch.stack(lower_bounds, dim=-1)
        upper_bounds = torch.stack(upper_bounds, dim=-1)

        # Reshape to original image size
        lower_bounds = lower_bounds.permute(0, 1, 4, 2, 3).reshape(
            B, C,
            self.num_blocks_h, self.num_blocks_w,
            self.block_size, self.block_size
        ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)

        upper_bounds = upper_bounds.permute(0, 1, 4, 2, 3).reshape(
            B, C,
            self.num_blocks_h, self.num_blocks_w,
            self.block_size, self.block_size
        ).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)

        return lower_bounds, upper_bounds
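A small self-contained sketch of the tracker's flow on dummy tensors; the error values are random stand-ins, not real residuals.

import torch
from models.uncertainty import BlockUncertaintyTracker

tracker = BlockUncertaintyTracker(block_size=4, alpha=0.1)

# Feed a batch of absolute reconstruction errors to update the per-block EMAs.
errors = torch.rand(8, 1, 32, 32)
tracker.update(errors)

# Before calibration, get_bounds falls back to the simple scaled bounds
# (0.9x, 1.1x) with a printed warning.
pred = torch.rand(2, 1, 32, 32)
lower, upper = tracker.get_bounds(pred, confidence_level=0.95)
print(lower.shape, upper.shape)  # each matches the prediction's shape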
models/vqvae.py
ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from .uncertainty import BlockUncertaintyTracker

class ResidualAttentionBlock(nn.Module):
    """Residual attention block for capturing spatial dependencies"""
    def __init__(self, in_channels):
        super().__init__()

        # Trunk branch
        self.trunk = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(in_channels)
        )

        # Mask branch for attention
        self.mask = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Trunk branch
        trunk_output = self.trunk(x)

        # Mask branch for attention weights
        attention = self.mask(x)

        # Apply attention and residual connection
        out = x + attention * trunk_output
        return F.silu(out)

class VectorQuantizer(nn.Module):
    """Vector quantizer for discrete latent representation"""
    def __init__(self, n_embeddings=512, embedding_dim=256, beta=0.25):
        super().__init__()
        self.n_embeddings = n_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta

        # Initialize embeddings
        self.embeddings = nn.Parameter(torch.randn(n_embeddings, embedding_dim))
        nn.init.uniform_(self.embeddings, -1.0 / n_embeddings, 1.0 / n_embeddings)

        # Usage tracking
        self.register_buffer('usage', torch.zeros(n_embeddings))

    def forward(self, z):
        # Reshape input for quantization
        z_flattened = z.reshape(-1, self.embedding_dim)

        # Calculate distances to embedding vectors
        distances = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
                    torch.sum(self.embeddings**2, dim=1) - \
                    2 * torch.matmul(z_flattened, self.embeddings.t())

        # Find nearest embedding for each input vector
        encoding_indices = torch.argmin(distances, dim=1)

        # Update usage statistics
        if self.training:
            with torch.no_grad():
                usage = torch.zeros_like(self.usage)
                usage.scatter_add_(0, encoding_indices, torch.ones_like(encoding_indices, dtype=torch.float))
                self.usage.mul_(0.99).add_(usage, alpha=0.01)

        # Get quantized vectors
        z_q = self.embeddings[encoding_indices].reshape(z.shape)

        # Calculate loss terms
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())

        # Combine losses
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator
        z_q = z + (z_q - z).detach()

        if self.training:
            return z_q, loss
        else:
            return z_q

class Encoder(nn.Module):
    """Encoder for VQ-VAE model"""
    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256], embedding_dim=256):
        super().__init__()

        # Initial conv layer
        layers = [
            nn.Conv2d(in_channels, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Hidden layers with downsampling
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.Conv2d(hidden_dims[i], hidden_dims[i+1], kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i+1]),
                nn.SiLU()
            ])

        # Residual attention blocks
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[-1]))

        # Final projection to embedding dimension
        layers.extend([
            nn.Conv2d(hidden_dims[-1], embedding_dim, kernel_size=1),
            nn.BatchNorm2d(embedding_dim)
        ])

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    """Decoder for VQ-VAE model"""
    def __init__(self, embedding_dim=256, hidden_dims=[256, 128, 64, 32], out_channels=1):
        super().__init__()

        # Reverse hidden dims for decoder
        hidden_dims = hidden_dims[::-1]

        # Initial processing
        layers = [
            nn.Conv2d(embedding_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dims[0]),
            nn.SiLU()
        ]

        # Residual attention blocks
        for _ in range(2):
            layers.append(ResidualAttentionBlock(hidden_dims[0]))

        # Upsampling blocks
        for i in range(len(hidden_dims) - 1):
            layers.extend([
                nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1],
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dims[i+1]),
                nn.SiLU()
            ])

        # Final output layer
        layers.append(
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1)
        )
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

class VQVAE(nn.Module):
    """
    Vector Quantized Variational Autoencoder with uncertainty awareness
    for bathymetry super-resolution
    """
    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128, 256],
                 num_embeddings=512, embedding_dim=256, block_size=4, alpha=0.1):
        super().__init__()

        # Initialize block-wise uncertainty tracking
        self.uncertainty_tracker = BlockUncertaintyTracker(
            block_size=block_size,
            alpha=alpha,
            decay=0.99,
            eps=1e-5
        )

        # Main model components
        self.encoder = Encoder(
            in_channels=in_channels,
            hidden_dims=hidden_dims,
            embedding_dim=embedding_dim
        )

        self.vq = VectorQuantizer(
            n_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            beta=0.25
        )

        self.decoder = Decoder(
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            out_channels=in_channels
        )

    def forward(self, x):
        """Forward pass through the model"""
        # Encode
        z = self.encoder(x)

        # Vector quantization
        if self.training:
            z_q, vq_loss = self.vq(z)

            # Decode
            reconstruction = self.decoder(z_q)

            return reconstruction, vq_loss
        else:
            z_q = self.vq(z)

            # Decode
            reconstruction = self.decoder(z_q)

            return reconstruction

    def train_forward(self, x, y):
        """Training forward pass with uncertainty tracking"""
        # Get reconstruction and VQ loss
        reconstruction, vq_loss = self.forward(x)

        # Calculate reconstruction error
        error = torch.abs(reconstruction - y)

        # Update uncertainty tracker
        self.uncertainty_tracker.update(error)

        # Get uncertainty map for loss weighting
        uncertainty_map = self.uncertainty_tracker.get_uncertainty(error)

        return reconstruction, vq_loss, uncertainty_map

    def predict_with_uncertainty(self, x, confidence_level=0.95):
        """
        Forward pass with calibrated uncertainty bounds

        Args:
            x: Input tensor
            confidence_level: Confidence level for bounds (default: 0.95)

        Returns:
            tuple: (reconstruction, lower_bounds, upper_bounds)
        """
        self.eval()
        with torch.no_grad():
            # Get reconstruction
            reconstruction = self.forward(x)

            # Get calibrated uncertainty bounds
            lower_bounds, upper_bounds = self.uncertainty_tracker.get_bounds(
                reconstruction, confidence_level
            )

            return reconstruction, lower_bounds, upper_bounds
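Finally, a sketch exercising the model API end to end with freshly initialized weights; outputs are meaningless, and with an uncalibrated tracker get_bounds falls back to the simple scaled bounds noted above. The hyperparameters mirror config.json.

import torch
from models.vqvae import VQVAE

# Untrained model, constructed with the hyperparameters from config.json.
model = VQVAE(in_channels=1, hidden_dims=[32, 64, 128, 256],
              num_embeddings=512, embedding_dim=256, block_size=4)
model.eval()

x = torch.rand(1, 1, 32, 32)
recon, lower, upper = model.predict_with_uncertainty(x, confidence_level=0.95)
print(recon.shape, lower.shape, upper.shape)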
requirements.txt
ADDED
torch>=2.0.0
numpy>=1.20.0
matplotlib>=3.5.0
gradio>=3.32.0
Pillow>=9.0.0
scipy>=1.8.0
tqdm>=4.62.0
huggingface_hub>=0.14.0
transformers>=4.30.0
pandas>=1.3.0
scikit-learn>=1.0.0