import json
import os
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Import your model components
from models.loader import ModelLoader
from models.uncertainty import BlockUncertaintyTracker


class BathymetrySuperResolution:
    """
    Bathymetry super-resolution model with uncertainty estimation.

    Inputs are z-score normalised with the configured mean/std, min-max scaled
    to [0, 1] using the input's own (z-score) range, and resized to 32x32
    before being passed to the model. Model outputs are mapped back to the
    original data units by inverting both normalisation steps.
    """

    def __init__(self, model_type="vqvae", checkpoint_path=None, config_path=None):
        """
        Initialize the super-resolution model with uncertainty awareness.

        Args:
            model_type: Type of model ('srcnn', 'gan', or 'vqvae'). Used for the
                default configuration, and as a fallback when a loaded config
                file does not specify "model_type".
            checkpoint_path: Path to model checkpoint (required).
            config_path: Optional path to a JSON configuration file. If it is
                given but does not exist, a warning is emitted and the default
                configuration is used.

        Raises:
            ValueError: If checkpoint_path is None or does not exist.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Default configuration
        default_config = {
            "model_type": model_type,
            "model_config": {
                "in_channels": 1,
                "hidden_dims": [32, 64, 128, 256],
                "num_embeddings": 512,
                "embedding_dim": 256,
                "block_size": 4
            },
            "normalization": {
                "mean": -3911.3894,
                "std": 1172.8374,
                "min": 0.0,
                "max": 1.0
            }
        }

        # Load config if provided
        if config_path is not None and os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)
            # A partial config file previously failed later with a KeyError;
            # fill in the keys that are read unconditionally below.
            # "model_config" is intentionally NOT filled: when absent, the
            # loader receives no overrides, as before.
            self.config.setdefault("model_type", model_type)
            self.config["normalization"] = {
                **default_config["normalization"],
                **self.config.get("normalization", {}),
            }
        else:
            if config_path is not None:
                # Falling back silently on a mistyped path would quietly
                # change the normalisation constants.
                warnings.warn(
                    f"Config file not found: {config_path!r}; using default configuration",
                    stacklevel=2,
                )
            self.config = default_config

        # Initialize model loader
        self.model_loader = ModelLoader()

        # Load model
        if checkpoint_path is None or not os.path.exists(checkpoint_path):
            raise ValueError("Checkpoint path not provided or invalid")
        self.model = self.model_loader.load_model(
            self.config['model_type'],
            checkpoint_path,
            config_overrides=self.config.get('model_config', {})
        )

        # preprocess() moves inputs to self.device, so the model must live on
        # the same device (no-op if the loader already placed it there).
        if isinstance(self.model, torch.nn.Module):
            self.model.to(self.device)

        # Ensure model is in eval mode
        self.model.eval()

        # Load normalization parameters
        normalization = self.config['normalization']
        self.mean = normalization['mean']
        self.std = normalization['std']
        self.min_val = normalization['min']
        self.max_val = normalization['max']

    @staticmethod
    def _to_batched_tensor(data):
        """Convert a PIL Image / numpy array / tensor to a float tensor with batch and channel dims."""
        # Convert PIL Image to numpy if needed
        if isinstance(data, Image.Image):
            data = np.array(data)

        # Convert numpy to tensor if needed
        if isinstance(data, np.ndarray):
            # torch.from_numpy rejects negative-stride views (e.g. np.flipud
            # output); a contiguous copy is only made when necessary.
            tensor = torch.from_numpy(np.ascontiguousarray(data)).float()
        else:
            tensor = data.float()

        # Add batch and channel dimensions if needed: (H, W) -> (1, 1, H, W),
        # and a 3-D input is assumed to be (C, H, W) -> (1, C, H, W).
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0).unsqueeze(0)
        elif tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        return tensor

    def _preprocess_with_stats(self, data):
        """
        Preprocess data and also return the z-score range used for min-max scaling.

        Returns:
            Tuple (tensor, z_min, z_max) where tensor is the model input on
            self.device and z_min/z_max are Python floats needed to invert the
            min-max step in denormalize().
        """
        tensor = self._to_batched_tensor(data)

        # Z-score normalisation with the configured statistics
        tensor = (tensor - self.mean) / (self.std + 1e-8)

        # Min-max scaling to [0, 1] using the range of the whole input tensor
        # (taken across the batch when several samples are passed).
        z_min = tensor.min()
        z_max = tensor.max()
        tensor = (tensor - z_min) / (z_max - z_min + 1e-8)

        # Resize if needed (to 32x32)
        if tensor.shape[-1] != 32 or tensor.shape[-2] != 32:
            tensor = F.interpolate(
                tensor,
                size=(32, 32),
                mode='bicubic',
                align_corners=False
            )

        return tensor.to(self.device), z_min.item(), z_max.item()

    def preprocess(self, data):
        """
        Preprocess input data for the model.

        Args:
            data: Input array/image (can be numpy array, PIL Image, or tensor)

        Returns:
            Preprocessed tensor of shape (N, C, 32, 32) on self.device
        """
        tensor, _, _ = self._preprocess_with_stats(data)
        return tensor

    def denormalize(self, tensor, data_min=None, data_max=None):
        """
        Denormalize output tensor.

        Args:
            tensor: Output tensor from model (in the [0, 1] scaled space)
            data_min: Lower end, in z-score units, of the min-max range used
                when the input was normalised. Defaults to the configured "min".
            data_max: Upper end of that range. Defaults to the configured "max".

        Returns:
            Denormalized tensor in original data range
        """
        low = self.min_val if data_min is None else data_min
        high = self.max_val if data_max is None else data_max

        # Undo min-max scaling: [0, 1] -> z-score units
        tensor = tensor * (high - low) + low

        # Undo z-score normalisation: z-score units -> original units
        tensor = tensor * self.std + self.mean

        return tensor

    def _to_output(self, tensor, data_min, data_max):
        """Denormalise a model output and convert it to a CPU numpy array; None passes through."""
        if tensor is None:
            return None
        return self.denormalize(tensor, data_min, data_max).cpu().numpy()

    def predict(self, data, with_uncertainty=True, confidence_level=0.95):
        """
        Generate super-resolution output with uncertainty bounds.

        Args:
            data: Input data (can be numpy array, PIL Image, or tensor)
            with_uncertainty: Whether to include uncertainty bounds
            confidence_level: Confidence level for uncertainty bounds

        Returns:
            (prediction, lower_bound, upper_bound) as numpy arrays when
            with_uncertainty=True AND the model provides
            predict_with_uncertainty (either bound may be None if the model
            returns None for it). Otherwise only the prediction array is
            returned -- including when uncertainty was requested but the model
            does not support it.
        """
        # Keep the input's z-score range so outputs are mapped back to the
        # input's actual depth range rather than the configured [min, max].
        input_tensor, z_min, z_max = self._preprocess_with_stats(data)

        with torch.no_grad():
            if with_uncertainty and hasattr(self.model, 'predict_with_uncertainty'):
                prediction, lower_bound, upper_bound = self.model.predict_with_uncertainty(
                    input_tensor, confidence_level
                )
                return (
                    self._to_output(prediction, z_min, z_max),
                    self._to_output(lower_bound, z_min, z_max),
                    self._to_output(upper_bound, z_min, z_max),
                )

            # Standard inference
            prediction = self.model(input_tensor)
            return self._to_output(prediction, z_min, z_max)

    def load_npy(self, file_path):
        """
        Load bathymetry data from numpy file.

        Args:
            file_path: Path to .npy file

        Returns:
            Numpy array containing bathymetry data

        Raises:
            ValueError: If the file cannot be loaded (original error chained).
        """
        try:
            return np.load(file_path)
        except Exception as e:
            raise ValueError(f"Error loading numpy file: {str(e)}") from e

    @staticmethod
    def get_uncertainty_width(lower_bound, upper_bound):
        """
        Calculate the mean uncertainty width.

        Args:
            lower_bound: Lower uncertainty bound (array)
            upper_bound: Upper uncertainty bound (array)

        Returns:
            The mean over all elements of (upper_bound - lower_bound), or None
            if either bound is None.
        """
        if lower_bound is None or upper_bound is None:
            return None

        return np.mean(upper_bound - lower_bound)