import os
import tempfile

import numpy as np
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling

import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array

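# Triton model configuration this backend expects (a sketch, not copied from an
# actual config.pbtxt; adjust names and dims to your deployment):
#
#     backend: "python"
#     input [
#       {
#         name: "input_jp2_bytes"   # red, green, NIR JP2 payloads (hex string or raw bytes)
#         data_type: TYPE_STRING
#         dims: [ 3 ]
#       }
#     ]
#     output [
#       {
#         name: "output_mask"       # flattened uint8 cloud mask
#         data_type: TYPE_UINT8
#         dims: [ -1 ]
#       }
#     ]
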
class TritonPythonModel:

    def initialize(self, args):
        """
        Initialize the model. This function is called once when the model is loaded.
        """
        print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')
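    # Note: `args` is a dict of strings provided by the Python backend. If per-instance
    # state were ever needed, it could be derived here, for example (optional sketch):
    #
    #     import json
    #     self.model_config = json.loads(args['model_config'])
    #
    # This model keeps no per-instance state, so initialize() only logs.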

    def safe_read_jp2_bytes(self, jp2_bytes):
        """
        Safely read JP2 bytes with multiple fallback methods.
        Returns (data, height, width, profile) for the first band as float32.
        """
        try:
            # Method 1: decode directly from memory with rasterio's MemoryFile.
            with MemoryFile(jp2_bytes) as memfile:
                with memfile.open() as src:
                    data = src.read(1).astype(np.float32)
                    height, width = src.height, src.width
                    profile = src.profile
                    return data, height, width, profile

        except Exception as e1:
            print(f"Method 1 (MemoryFile) failed: {e1}")
            try:
                # Method 2: spill the bytes to a temporary .jp2 file and open it from disk.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.jp2') as tmp_file:
                    tmp_file.write(jp2_bytes)
                    tmp_file.flush()

                    with rasterio.open(tmp_file.name) as src:
                        data = src.read(1).astype(np.float32)
                        height, width = src.height, src.width
                        profile = src.profile

                os.unlink(tmp_file.name)
                return data, height, width, profile

            except Exception as e2:
                print(f"Method 2 (temporary file) failed: {e2}")
                try:
                    # Method 3: retry the temporary-file approach with a .tiff suffix.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
                        tmp_file.write(jp2_bytes)
                        tmp_file.flush()

                        with rasterio.open(tmp_file.name) as src:
                            data = src.read(1).astype(np.float32)
                            height, width = src.height, src.width
                            profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}

                    os.unlink(tmp_file.name)
                    return data, height, width, profile

                except Exception as e3:
                    print(f"Method 3 (tiff fallback) failed: {e3}")

                    try:
                        # Method 4: last resort, interpret the payload as a raw float32 buffer.
                        data_array = np.frombuffer(jp2_bytes, dtype=np.float32)

                        # Prefer a square layout if the length allows it.
                        side_length = int(np.sqrt(len(data_array)))
                        if side_length * side_length == len(data_array):
                            data = data_array.reshape(side_length, side_length)
                            height, width = side_length, side_length
                            profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
                            return data, height, width, profile
                        else:
                            # Otherwise try common Sentinel-2 band dimensions.
                            common_dims = [(10980, 10980), (5490, 5490), (1024, 1024), (512, 512)]
                            for h, w in common_dims:
                                if h * w == len(data_array):
                                    data = data_array.reshape(h, w)
                                    height, width = h, w
                                    profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
                                    return data, height, width, profile

                            raise ValueError(f"Cannot interpret data array of length {len(data_array)} as image")

                    except Exception as e4:
                        raise Exception(
                            f"All fallback methods failed: MemoryFile({e1}), TempFile({e2}), "
                            f"TiffFallback({e3}), RawBytes({e4})"
                        )
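    # Quick way to exercise the reader in isolation (illustrative only; assumes the
    # module-level imports, including triton_python_backend_utils, resolve in your
    # environment, and the filename is a placeholder for any single-band raster):
    #
    #     with open("B04.jp2", "rb") as f:
    #         data, h, w, profile = TritonPythonModel().safe_read_jp2_bytes(f.read())
    #     print(data.shape, h, w)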

    def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
        """
        Safely resample data to target dimensions with fallback methods.
        """
        if current_height == target_height and current_width == target_width:
            return data

        try:
            # Method 1: write the band into an in-memory rasterio dataset and read it
            # back at the target shape with bilinear resampling.
            temp_profile = profile.copy()
            temp_profile.update({
                'height': current_height,
                'width': current_width,
                'count': 1,
                'dtype': 'float32'
            })

            with MemoryFile() as memfile:
                with memfile.open(**temp_profile) as temp_dataset:
                    temp_dataset.write(data, 1)

                    resampled = temp_dataset.read(
                        out_shape=(1, target_height, target_width),
                        resampling=Resampling.bilinear
                    )[0].astype(np.float32)

                    return resampled

        except Exception as e1:
            print(f"Rasterio resampling failed: {e1}")
            try:
                # Method 2: bilinear zoom via scipy, if available.
                from scipy import ndimage
                zoom_factors = (target_height / current_height, target_width / current_width)
                resampled = ndimage.zoom(data, zoom_factors, order=1)
                return resampled.astype(np.float32)

            except ImportError:
                print("Scipy not available for resampling")

                # Method 3: nearest-neighbour sampling with plain numpy indexing.
                h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
                w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)

                resampled = data[np.ix_(h_indices, w_indices)]
                return resampled.astype(np.float32)

            except Exception as e2:
                print(f"Scipy resampling failed: {e2}")

                # Same nearest-neighbour fallback when scipy is present but fails.
                h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
                w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)

                resampled = data[np.ix_(h_indices, w_indices)]
                return resampled.astype(np.float32)
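    # Illustrative call (Sentinel-2 L1C tile sizes, matching the dimensions used above):
    # upsample a 20 m band (5490 x 5490) onto the 10 m grid (10980 x 10980) defined by
    # the red band:
    #
    #     band_10m = self.safe_resample_data(band_20m, 5490, 5490, 10980, 10980, profile_20m)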

    def execute(self, requests):
        """
        Process inference requests with robust error handling.
        """
        responses = []

        for request in requests:
            try:
                input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
                jp2_bytes_list = input_tensor.as_numpy()

                if len(jp2_bytes_list) != 3:
                    error_msg = f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                # Band order: red, green, NIR.
                red_hex = jp2_bytes_list[0]
                green_hex = jp2_bytes_list[1]
                nir_hex = jp2_bytes_list[2]

                # The client may send hex-encoded strings or raw bytes; normalise to bytes.
                try:
                    if isinstance(red_hex, str):
                        red_bytes = bytes.fromhex(red_hex)
                        green_bytes = bytes.fromhex(green_hex)
                        nir_bytes = bytes.fromhex(nir_hex)
                        print("Decoded hex strings to bytes")
                    elif isinstance(red_hex, (bytes, np.bytes_)):
                        red_bytes = bytes(red_hex)
                        green_bytes = bytes(green_hex)
                        nir_bytes = bytes(nir_hex)
                        print("Input already in bytes format")
                    else:
                        red_bytes = bytes.fromhex(str(red_hex))
                        green_bytes = bytes.fromhex(str(green_hex))
                        nir_bytes = bytes.fromhex(str(nir_hex))
                        print("Converted numpy strings to bytes")
                except Exception as e:
                    error_msg = f"Failed to decode input data: {str(e)}"
                    print(f"Decode error: {error_msg}")
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                print(f"Processing JP2 data - decoded sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")

                # The red band defines the target grid; the other bands are resampled to match it.
                red_data, target_height, target_width, red_profile = self.safe_read_jp2_bytes(red_bytes)
                print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")

                green_data, green_height, green_width, green_profile = self.safe_read_jp2_bytes(green_bytes)
                green_data = self.safe_resample_data(green_data, green_height, green_width, target_height, target_width, green_profile)
                print(f"Green band after resampling: {green_data.shape}")

                nir_data, nir_height, nir_width, nir_profile = self.safe_read_jp2_bytes(nir_bytes)
                nir_data = self.safe_resample_data(nir_data, nir_height, nir_width, target_height, target_width, nir_profile)
                print(f"NIR band after resampling: {nir_data.shape}")

                if not (red_data.shape == green_data.shape == nir_data.shape):
                    shapes = [red_data.shape, green_data.shape, nir_data.shape]
                    error_msg = f"Band shape mismatch after resampling: {shapes}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                if red_data.shape[0] == 0 or red_data.shape[1] == 0:
                    error_msg = f"Invalid band dimensions: {red_data.shape}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                # Stack into the (bands, height, width) array passed to omnicloudmask.
                prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
                print(f"Final prediction array shape: {prediction_array.shape}")

                cloud_mask = predict_from_array(prediction_array)
                print(f"Cloud mask shape: {cloud_mask.shape}")

                # Flatten to a 1-D mask for the output tensor.
                if cloud_mask.ndim > 1:
                    cloud_mask = cloud_mask.flatten()

                output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
                response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
                responses.append(response)

            except Exception as e:
                error_msg = f"Error processing JP2 data: {str(e)}"
                print(f"Model execution error: {error_msg}")
                error = pb_utils.TritonError(error_msg)
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                responses.append(response)

        return responses

    def finalize(self):
        """
        Clean up when the model is unloaded.
        """
        print('Cloud Detection model finalized')
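

# Example client call (a sketch; assumes an HTTP endpoint at localhost:8000, a model
# named "cloud_detection", and placeholder filenames for the red, green and NIR JP2s):
#
#     import numpy as np
#     import tritonclient.http as httpclient
#
#     client = httpclient.InferenceServerClient(url="localhost:8000")
#     bands = [open(p, "rb").read() for p in ("B04.jp2", "B03.jp2", "B08.jp2")]
#     inp = httpclient.InferInput("input_jp2_bytes", [3], "BYTES")
#     inp.set_data_from_numpy(np.array(bands, dtype=object))
#     result = client.infer("cloud_detection", inputs=[inp])
#     mask = result.as_numpy("output_mask")  # 1-D uint8; reshape with the red band's H x W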