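"""
Triton Python backend model for cloud detection on satellite imagery.

The model receives a single BYTES tensor named "input_jp2_bytes" containing
three JP2-encoded bands (red, green, NIR), passed either as raw bytes or as
hex strings. Each band is decoded with rasterio/GDAL (with several fallback
strategies), the green and NIR bands are resampled onto the red band's grid,
omnicloudmask's predict_from_array is run on the stacked bands, and the
resulting cloud mask is returned flattened as a uint8 tensor named
"output_mask".
"""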
import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling
import tempfile
import os

class TritonPythonModel:
    def initialize(self, args):
        """
        Initialize the model. This function is called once when the model is loaded.
        """
        print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')

    def safe_read_jp2_bytes(self, jp2_bytes):
        """
        Safely read JP2 bytes with multiple fallback methods
        """
        try:
            # Method 1: Try direct MemoryFile approach (works if GDAL drivers are properly configured)
            with MemoryFile(jp2_bytes) as memfile:
                with memfile.open() as src:
                    data = src.read(1).astype(np.float32)
                    height, width = src.height, src.width
                    profile = src.profile
                    return data, height, width, profile
                    
        except Exception as e1:
            print(f"Method 1 (MemoryFile) failed: {e1}")
            try:
                # Method 2: Write to a temporary file and read it back from disk
                with tempfile.NamedTemporaryFile(delete=False, suffix='.jp2') as tmp_file:
                    tmp_file.write(jp2_bytes)
                    tmp_path = tmp_file.name

                try:
                    with rasterio.open(tmp_path) as src:
                        data = src.read(1).astype(np.float32)
                        height, width = src.height, src.width
                        profile = src.profile
                    return data, height, width, profile
                finally:
                    # Clean up the temporary file even if reading fails
                    os.unlink(tmp_path)

            except Exception as e2:
                print(f"Method 2 (temporary file) failed: {e2}")
                try:
                    # Method 3: Retry with a .tiff suffix and a basic profile
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
                        tmp_file.write(jp2_bytes)
                        tmp_path = tmp_file.name

                    try:
                        with rasterio.open(tmp_path) as src:
                            data = src.read(1).astype(np.float32)
                            height, width = src.height, src.width
                            profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
                        return data, height, width, profile
                    finally:
                        # Clean up the temporary file even if reading fails
                        os.unlink(tmp_path)

                except Exception as e3:
                    print(f"Method 3 (tiff fallback) failed: {e3}")
                    # Method 4: Final fallback - try to interpret as raw numpy array
                    try:
                        # This assumes the data might be raw numpy bytes as fallback
                        data_array = np.frombuffer(jp2_bytes, dtype=np.float32)
                        
                        # Try to guess square dimensions
                        side_length = int(np.sqrt(len(data_array)))
                        if side_length * side_length == len(data_array):
                            data = data_array.reshape(side_length, side_length)
                            height, width = side_length, side_length
                            profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
                            return data, height, width, profile
                        else:
                            # Try common satellite image dimensions
                            common_dims = [(10980, 10980), (5490, 5490), (1024, 1024), (512, 512)]
                            for h, w in common_dims:
                                if h * w == len(data_array):
                                    data = data_array.reshape(h, w)
                                    height, width = h, w
                                    profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
                                    return data, height, width, profile
                            
                            raise ValueError(f"Cannot interpret data array of length {len(data_array)} as image")
                            
                    except Exception as e4:
                        raise Exception(f"All fallback methods failed: MemoryFile({e1}), TempFile({e2}), TiffFallback({e3}), RawBytes({e4})")

    def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
        """
        Safely resample data to target dimensions with fallback methods
        """
        if current_height == target_height and current_width == target_width:
            return data
            
        try:
            # Method 1: Use rasterio resampling
            temp_profile = profile.copy()
            temp_profile.update({
                'height': current_height,
                'width': current_width,
                'count': 1,
                'dtype': 'float32'
            })
            
            with MemoryFile() as memfile:
                with memfile.open(**temp_profile) as temp_dataset:
                    temp_dataset.write(data, 1)
                    
                    resampled = temp_dataset.read(
                        out_shape=(1, target_height, target_width),
                        resampling=Resampling.bilinear
                    )[0].astype(np.float32)
                    
                    return resampled
                    
        except Exception as e1:
            print(f"Rasterio resampling failed: {e1}")
            try:
                # Method 2: Use scipy if available
                from scipy import ndimage
                zoom_factors = (target_height / current_height, target_width / current_width)
                resampled = ndimage.zoom(data, zoom_factors, order=1)
                return resampled.astype(np.float32)
                
            except Exception as e2:
                # Covers both a missing scipy installation and a failed scipy call
                print(f"Scipy resampling unavailable or failed: {e2}")
                # Method 3: Simple nearest-neighbor resampling
                h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
                w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)

                resampled = data[np.ix_(h_indices, w_indices)]
                return resampled.astype(np.float32)

    def execute(self, requests):
        """
        Process inference requests with robust error handling.
        """
        responses = []

        for request in requests:
            try:
                input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
                jp2_bytes_list = input_tensor.as_numpy()

                if len(jp2_bytes_list) != 3:
                    error_msg = f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                # The input might be hex strings, decode them to bytes first
                red_hex = jp2_bytes_list[0]
                green_hex = jp2_bytes_list[1] 
                nir_hex = jp2_bytes_list[2]

                # Convert hex strings to bytes
                try:
                    if isinstance(red_hex, str):
                        red_bytes = bytes.fromhex(red_hex)
                        green_bytes = bytes.fromhex(green_hex)
                        nir_bytes = bytes.fromhex(nir_hex)
                        print(f"Decoded hex strings to bytes")
                    elif isinstance(red_hex, (bytes, np.bytes_)):
                        # Already bytes, use directly
                        red_bytes = bytes(red_hex)
                        green_bytes = bytes(green_hex) 
                        nir_bytes = bytes(nir_hex)
                        print(f"Input already in bytes format")
                    else:
                        # Might be numpy string object
                        red_bytes = bytes.fromhex(str(red_hex))
                        green_bytes = bytes.fromhex(str(green_hex))
                        nir_bytes = bytes.fromhex(str(nir_hex))
                        print(f"Converted numpy strings to bytes")
                except Exception as e:
                    error_msg = f"Failed to decode input data: {str(e)}"
                    print(f"Decode error: {error_msg}")
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                print(f"Processing JP2 data - decoded sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")

                # Read red band data (use as reference for dimensions)
                red_data, target_height, target_width, red_profile = self.safe_read_jp2_bytes(red_bytes)
                print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")

                # Read and resample green band
                green_data, green_height, green_width, green_profile = self.safe_read_jp2_bytes(green_bytes)
                green_data = self.safe_resample_data(green_data, green_height, green_width, target_height, target_width, green_profile)
                print(f"Green band after resampling: {green_data.shape}")

                # Read and resample NIR band
                nir_data, nir_height, nir_width, nir_profile = self.safe_read_jp2_bytes(nir_bytes)
                nir_data = self.safe_resample_data(nir_data, nir_height, nir_width, target_height, target_width, nir_profile)
                print(f"NIR band after resampling: {nir_data.shape}")

                # Verify all bands have the same shape
                if not (red_data.shape == green_data.shape == nir_data.shape):
                    shapes = [red_data.shape, green_data.shape, nir_data.shape]
                    error_msg = f"Band shape mismatch after resampling: {shapes}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                # Check for valid dimensions
                if red_data.shape[0] == 0 or red_data.shape[1] == 0:
                    error_msg = f"Invalid band dimensions: {red_data.shape}"
                    error = pb_utils.TritonError(error_msg)
                    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                    responses.append(response)
                    continue

                # Stack bands in CHW format for prediction (channels, height, width)
                prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
                print(f"Final prediction array shape: {prediction_array.shape}")

                # Run cloud detection prediction
                cloud_mask = predict_from_array(prediction_array)
                print(f"Cloud mask shape: {cloud_mask.shape}")

                # Flatten the mask for output
                if cloud_mask.ndim > 1:
                    cloud_mask = cloud_mask.flatten()

                # Create output tensor (config expects TYPE_UINT8)
                output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
                response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
                responses.append(response)

            except Exception as e:
                # Enhanced error reporting
                error_msg = f"Error processing JP2 data: {str(e)}"
                print(f"Model execution error: {error_msg}")
                error = pb_utils.TritonError(error_msg)
                response = pb_utils.InferenceResponse(output_tensors=[], error=error)
                responses.append(response)

        return responses

    def finalize(self):
        """
        Clean up when the model is unloaded.
        """
        print('Cloud Detection model finalized')
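

# ---------------------------------------------------------------------------
# Illustrative client-side sketch (not executed by Triton, which only imports
# this module). It shows one way the model above could be called over HTTP,
# assuming it is served under the hypothetical name "cloud_detection" on a
# local Triton endpoint, with the serving config declaring the BYTES input
# "input_jp2_bytes" with dims [ 3 ] and the UINT8 output "output_mask".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys
    import tritonclient.http as httpclient

    # Paths to the red, green, and NIR JP2 files, taken from the command line.
    red_path, green_path, nir_path = sys.argv[1:4]

    # Read the three encoded bands as raw bytes (execute() also accepts hex strings).
    bands = []
    for path in (red_path, green_path, nir_path):
        with open(path, 'rb') as f:
            bands.append(f.read())

    client = httpclient.InferenceServerClient(url="localhost:8000")

    # Pack the three byte strings into a single BYTES tensor, ordered red, green, NIR.
    infer_input = httpclient.InferInput("input_jp2_bytes", [3], "BYTES")
    infer_input.set_data_from_numpy(np.array(bands, dtype=np.object_))

    result = client.infer(model_name="cloud_detection", inputs=[infer_input])
    mask = result.as_numpy("output_mask")
    print(f"Received flattened cloud mask with {mask.size} values")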