# cloud-detection / model.py
# Page header captured with this file: truthdotphd, "Upload model.py",
# commit e47d31c (verified), 13.1 kB.
import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling
import tempfile
import os
from io import BytesIO
class TritonPythonModel:
    """Triton Python-backend model that derives a cloud mask from three bands.

    Each request is expected to carry an ``input_jp2_bytes`` tensor with
    exactly three entries, treated as the red, green and NIR bands in that
    order.  The bands are decoded, brought to the red band's dimensions,
    stacked and passed to ``predict_from_array``; the flattened result is
    returned in the ``output_mask`` tensor.
    """

    def initialize(self, args):
        """Log that the model has been loaded.

        Args:
            args: Arguments handed over by the backend at load time; unused.
        """
        print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')

    @staticmethod
    def _basic_profile(height, width):
        """Return the minimal single-band float32 GTiff profile used by fallbacks."""
        return {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}

    @staticmethod
    def _read_with_memoryfile(jp2_bytes):
        """Decode band 1 straight from memory via ``rasterio.io.MemoryFile``."""
        with MemoryFile(jp2_bytes) as memfile:
            with memfile.open() as src:
                data = src.read(1).astype(np.float32)
                return data, src.height, src.width, src.profile

    def _read_with_tempfile(self, jp2_bytes, suffix, keep_profile):
        """Decode band 1 by writing the payload to a temporary file on disk.

        Args:
            jp2_bytes: Encoded image payload.
            suffix: File suffix given to the temporary file.
            keep_profile: If true return the dataset's own profile, otherwise
                a basic GTiff profile built from the decoded dimensions.

        The temporary file is removed whether or not decoding succeeds.
        """
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_path = tmp_file.name
                tmp_file.write(jp2_bytes)
            # The handle is closed (and therefore flushed) before rasterio
            # opens the path, so the reader sees the complete payload.
            with rasterio.open(tmp_path) as src:
                data = src.read(1).astype(np.float32)
                height, width = src.height, src.width
                profile = src.profile if keep_profile else self._basic_profile(height, width)
            return data, height, width, profile
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError as cleanup_error:
                    # Best effort: a failed cleanup must not hide the read result.
                    print(f"Could not remove temporary file {tmp_path}: {cleanup_error}")

    def _read_as_raw_float32(self, raw_bytes):
        """Interpret the payload as a raw float32 buffer holding a square image.

        Raises:
            ValueError: If the buffer is empty, is not a whole number of
                float32 values, or its element count is not a perfect square.
        """
        if len(raw_bytes) == 0:
            raise ValueError("Cannot interpret empty data as image")
        data_array = np.frombuffer(raw_bytes, dtype=np.float32)
        pixel_count = len(data_array)
        # Only square layouts can be inferred from the element count alone.
        side_length = int(np.sqrt(pixel_count))
        if side_length * side_length != pixel_count:
            raise ValueError(f"Cannot interpret data array of length {pixel_count} as image")
        data = data_array.reshape(side_length, side_length)
        return data, side_length, side_length, self._basic_profile(side_length, side_length)

    def safe_read_jp2_bytes(self, jp2_bytes):
        """Decode one band from encoded bytes, trying several readers in turn.

        Readers are attempted in this order: in-memory rasterio dataset,
        temporary ``.jp2`` file, temporary ``.tiff`` file, raw float32 buffer.
        The first reader that succeeds wins.

        Args:
            jp2_bytes: Encoded image payload.

        Returns:
            Tuple ``(data, height, width, profile)`` where ``data`` is a 2-D
            float32 array.

        Raises:
            RuntimeError: If every reader fails; the message lists each
                reader's error.
        """
        attempts = (
            ("MemoryFile", "Method 1 (MemoryFile)",
             lambda: self._read_with_memoryfile(jp2_bytes)),
            ("TempFile", "Method 2 (temporary file)",
             lambda: self._read_with_tempfile(jp2_bytes, '.jp2', True)),
            ("TiffFallback", "Method 3 (tiff fallback)",
             lambda: self._read_with_tempfile(jp2_bytes, '.tiff', False)),
            # The last reader is not logged separately: its failure is
            # reported through the aggregated exception below.
            ("RawBytes", None,
             lambda: self._read_as_raw_float32(jp2_bytes)),
        )
        failures = []
        last_error = None
        for label, log_name, reader in attempts:
            try:
                return reader()
            except Exception as read_error:
                if log_name is not None:
                    print(f"{log_name} failed: {read_error}")
                failures.append(f"{label}({read_error})")
                last_error = read_error
        raise RuntimeError(
            "All fallback methods failed: " + ", ".join(failures)
        ) from last_error

    @staticmethod
    def _nearest_neighbor_resample(data, current_height, current_width, target_height, target_width):
        """Resample ``data`` by picking the nearest source row/column per target cell."""
        h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
        w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
        return data[np.ix_(h_indices, w_indices)].astype(np.float32)

    def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
        """Resample a band to the target dimensions, with fallbacks.

        Tries rasterio bilinear resampling first, then ``scipy.ndimage.zoom``
        with ``order=1``, and finally a NumPy nearest-neighbour lookup.

        Args:
            data: 2-D band array.
            current_height: Height of ``data``.
            current_width: Width of ``data``.
            target_height: Desired height.
            target_width: Desired width.
            profile: Dataset profile used as the basis for the temporary
                in-memory dataset.

        Returns:
            ``data`` itself when the dimensions already match, otherwise a
            float32 array resampled to the target dimensions.
        """
        if current_height == target_height and current_width == target_width:
            return data

        # Method 1: rasterio resampling through an in-memory dataset.
        try:
            temp_profile = profile.copy()
            temp_profile.update({
                'height': current_height,
                'width': current_width,
                'count': 1,
                'dtype': 'float32',
            })
            with MemoryFile() as memfile:
                with memfile.open(**temp_profile) as temp_dataset:
                    temp_dataset.write(data, 1)
                    resampled = temp_dataset.read(
                        out_shape=(1, target_height, target_width),
                        resampling=Resampling.bilinear,
                    )[0].astype(np.float32)
            return resampled
        except Exception as rasterio_error:
            print(f"Rasterio resampling failed: {rasterio_error}")

        # Method 2: scipy, if it is installed.
        try:
            from scipy import ndimage
        except ImportError:
            print("Scipy not available for resampling")
        else:
            try:
                zoom_factors = (target_height / current_height, target_width / current_width)
                resampled = ndimage.zoom(data, zoom_factors, order=1)
                return resampled.astype(np.float32)
            except Exception as scipy_error:
                print(f"Scipy resampling failed: {scipy_error}")

        # Method 3: simple nearest-neighbour resampling with NumPy only.
        return self._nearest_neighbor_resample(
            data, current_height, current_width, target_height, target_width
        )

    @staticmethod
    def _error_response(message):
        """Build an inference response that carries only an error message."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    @staticmethod
    def _decode_band_bytes(raw_value):
        """Convert one input element to ``bytes``.

        ``bytes``/``np.bytes_`` values are used as they are; anything else is
        converted with ``str()`` and parsed as a hex string.
        """
        if isinstance(raw_value, (bytes, np.bytes_)):
            return bytes(raw_value)
        return bytes.fromhex(str(raw_value))

    def _process_request(self, request):
        """Run the full pipeline for one request and return its response."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
        jp2_bytes_list = input_tensor.as_numpy()
        if len(jp2_bytes_list) != 3:
            return self._error_response(
                f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
            )

        # Each element is decoded according to its own type, so a mix of raw
        # bytes and hex strings in one request is handled.
        try:
            red_bytes, green_bytes, nir_bytes = [
                self._decode_band_bytes(raw_value) for raw_value in jp2_bytes_list
            ]
        except Exception as e:
            error_msg = f"Failed to decode input data: {str(e)}"
            print(f"Decode error: {error_msg}")
            return self._error_response(error_msg)

        print(f"Processing JP2 data - decoded sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")

        # The red band defines the target dimensions for the other two bands.
        red_data, target_height, target_width, red_profile = self.safe_read_jp2_bytes(red_bytes)
        print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")

        green_data, green_height, green_width, green_profile = self.safe_read_jp2_bytes(green_bytes)
        green_data = self.safe_resample_data(
            green_data, green_height, green_width, target_height, target_width, green_profile
        )
        print(f"Green band after resampling: {green_data.shape}")

        nir_data, nir_height, nir_width, nir_profile = self.safe_read_jp2_bytes(nir_bytes)
        nir_data = self.safe_resample_data(
            nir_data, nir_height, nir_width, target_height, target_width, nir_profile
        )
        print(f"NIR band after resampling: {nir_data.shape}")

        if not (red_data.shape == green_data.shape == nir_data.shape):
            shapes = [red_data.shape, green_data.shape, nir_data.shape]
            return self._error_response(f"Band shape mismatch after resampling: {shapes}")

        if red_data.shape[0] == 0 or red_data.shape[1] == 0:
            return self._error_response(f"Invalid band dimensions: {red_data.shape}")

        # Stack bands channel-first (channels, height, width) for prediction.
        prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
        print(f"Final prediction array shape: {prediction_array.shape}")

        cloud_mask = predict_from_array(prediction_array)
        print(f"Cloud mask shape: {cloud_mask.shape}")

        # The mask is returned as a flat vector.
        if cloud_mask.ndim > 1:
            cloud_mask = cloud_mask.flatten()

        # Cast to uint8; per the original author's note the model config
        # expects TYPE_UINT8 for this output.
        output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
        return pb_utils.InferenceResponse(output_tensors=[output_tensor])

    def execute(self, requests):
        """Process inference requests, returning one response per request.

        Any failure while handling a request is turned into an error response
        for that request, so one bad request does not abort the others.

        Args:
            requests: Iterable of inference requests.

        Returns:
            List of inference responses in the same order as ``requests``.
        """
        responses = []
        for request in requests:
            try:
                responses.append(self._process_request(request))
            except Exception as e:
                error_msg = f"Error processing JP2 data: {str(e)}"
                print(f"Model execution error: {error_msg}")
                responses.append(self._error_response(error_msg))
        return responses

    def finalize(self):
        """Log that the model is being unloaded."""
        print('Cloud Detection model finalized')