truthdotphd commited on
Commit
0d9b472
·
verified ·
1 Parent(s): 261df98

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +66 -228
model.py CHANGED
@@ -4,259 +4,97 @@ from omnicloudmask import predict_from_array
4
  import rasterio
5
  from rasterio.io import MemoryFile
6
  from rasterio.enums import Resampling
7
- import tempfile
8
- import os
9
- from io import BytesIO
10
 
11
  class TritonPythonModel:
12
  def initialize(self, args):
13
  """
14
  Initialize the model. This function is called once when the model is loaded.
15
  """
16
- print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')
17
-
18
- def safe_read_jp2_bytes(self, jp2_bytes):
19
- """
20
- Safely read JP2 bytes with multiple fallback methods
21
- """
22
- try:
23
- # Method 1: Try direct MemoryFile approach (works if GDAL drivers are properly configured)
24
- with MemoryFile(jp2_bytes) as memfile:
25
- with memfile.open() as src:
26
- data = src.read(1).astype(np.float32)
27
- height, width = src.height, src.width
28
- profile = src.profile
29
- return data, height, width, profile
30
-
31
- except Exception as e1:
32
- print(f"Method 1 (MemoryFile) failed: {e1}")
33
- try:
34
- # Method 2: Write to temporary file and read from disk
35
- with tempfile.NamedTemporaryFile(delete=False, suffix='.jp2') as tmp_file:
36
- tmp_file.write(jp2_bytes)
37
- tmp_file.flush()
38
-
39
- with rasterio.open(tmp_file.name) as src:
40
- data = src.read(1).astype(np.float32)
41
- height, width = src.height, src.width
42
- profile = src.profile
43
-
44
- # Clean up temporary file
45
- os.unlink(tmp_file.name)
46
- return data, height, width, profile
47
-
48
- except Exception as e2:
49
- print(f"Method 2 (temporary file) failed: {e2}")
50
- try:
51
- # Method 3: Try with different suffix and basic profile
52
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
53
- tmp_file.write(jp2_bytes)
54
- tmp_file.flush()
55
-
56
- with rasterio.open(tmp_file.name) as src:
57
- data = src.read(1).astype(np.float32)
58
- height, width = src.height, src.width
59
- profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
60
-
61
- os.unlink(tmp_file.name)
62
- return data, height, width, profile
63
-
64
- except Exception as e3:
65
- print(f"Method 3 (tiff fallback) failed: {e3}")
66
- # Method 4: Final fallback - try to interpret as raw numpy array
67
- try:
68
- # This assumes the data might be raw numpy bytes as fallback
69
- data_array = np.frombuffer(jp2_bytes, dtype=np.float32)
70
-
71
- # Try to guess square dimensions
72
- side_length = int(np.sqrt(len(data_array)))
73
- if side_length * side_length == len(data_array):
74
- data = data_array.reshape(side_length, side_length)
75
- height, width = side_length, side_length
76
- profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
77
- return data, height, width, profile
78
- else:
79
- # Try common satellite image dimensions
80
- common_dims = [(10980, 10980), (5490, 5490), (1024, 1024), (512, 512)]
81
- for h, w in common_dims:
82
- if h * w == len(data_array):
83
- data = data_array.reshape(h, w)
84
- height, width = h, w
85
- profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
86
- return data, height, width, profile
87
-
88
- raise ValueError(f"Cannot interpret data array of length {len(data_array)} as image")
89
-
90
- except Exception as e4:
91
- raise Exception(f"All fallback methods failed: MemoryFile({e1}), TempFile({e2}), TiffFallback({e3}), RawBytes({e4})")
92
-
93
- def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
94
- """
95
- Safely resample data to target dimensions with fallback methods
96
- """
97
- if current_height == target_height and current_width == target_width:
98
- return data
99
-
100
- try:
101
- # Method 1: Use rasterio resampling
102
- temp_profile = profile.copy()
103
- temp_profile.update({
104
- 'height': current_height,
105
- 'width': current_width,
106
- 'count': 1,
107
- 'dtype': 'float32'
108
- })
109
-
110
- with MemoryFile() as memfile:
111
- with memfile.open(**temp_profile) as temp_dataset:
112
- temp_dataset.write(data, 1)
113
-
114
- resampled = temp_dataset.read(
115
- out_shape=(1, target_height, target_width),
116
- resampling=Resampling.bilinear
117
- )[0].astype(np.float32)
118
-
119
- return resampled
120
-
121
- except Exception as e1:
122
- print(f"Rasterio resampling failed: {e1}")
123
- try:
124
- # Method 2: Use scipy if available
125
- from scipy import ndimage
126
- zoom_factors = (target_height / current_height, target_width / current_width)
127
- resampled = ndimage.zoom(data, zoom_factors, order=1)
128
- return resampled.astype(np.float32)
129
-
130
- except ImportError:
131
- print("Scipy not available for resampling")
132
- # Method 3: Simple nearest-neighbor resampling
133
- h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
134
- w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
135
-
136
- resampled = data[np.ix_(h_indices, w_indices)]
137
- return resampled.astype(np.float32)
138
-
139
- except Exception as e2:
140
- print(f"Scipy resampling failed: {e2}")
141
- # Method 3: Simple nearest-neighbor resampling
142
- h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
143
- w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
144
-
145
- resampled = data[np.ix_(h_indices, w_indices)]
146
- return resampled.astype(np.float32)
147
 
148
  def execute(self, requests):
149
  """
150
- Process inference requests with robust error handling.
151
  """
152
  responses = []
153
-
154
  for request in requests:
155
- try:
156
- input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
157
- jp2_bytes_list = input_tensor.as_numpy()
158
-
159
- if len(jp2_bytes_list) != 3:
160
- error_msg = f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
161
- error = pb_utils.TritonError(error_msg)
162
- response = pb_utils.InferenceResponse(output_tensors=[], error=error)
163
- responses.append(response)
164
- continue
165
-
166
- # The input might be hex strings, decode them to bytes first
167
- red_hex = jp2_bytes_list[0]
168
- green_hex = jp2_bytes_list[1]
169
- nir_hex = jp2_bytes_list[2]
170
-
171
- # Convert hex strings to bytes
172
- try:
173
- if isinstance(red_hex, str):
174
- red_bytes = bytes.fromhex(red_hex)
175
- green_bytes = bytes.fromhex(green_hex)
176
- nir_bytes = bytes.fromhex(nir_hex)
177
- print(f"Decoded hex strings to bytes")
178
- elif isinstance(red_hex, (bytes, np.bytes_)):
179
- # Already bytes, use directly
180
- red_bytes = bytes(red_hex)
181
- green_bytes = bytes(green_hex)
182
- nir_bytes = bytes(nir_hex)
183
- print(f"Input already in bytes format")
184
- else:
185
- # Might be numpy string object
186
- red_bytes = bytes.fromhex(str(red_hex))
187
- green_bytes = bytes.fromhex(str(green_hex))
188
- nir_bytes = bytes.fromhex(str(nir_hex))
189
- print(f"Converted numpy strings to bytes")
190
- except Exception as e:
191
- error_msg = f"Failed to decode input data: {str(e)}"
192
- print(f"Decode error: {error_msg}")
193
- error = pb_utils.TritonError(error_msg)
194
- response = pb_utils.InferenceResponse(output_tensors=[], error=error)
195
- responses.append(response)
196
- continue
197
-
198
- print(f"Processing JP2 data - decoded sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")
199
-
200
- # Read red band data (use as reference for dimensions)
201
- red_data, target_height, target_width, red_profile = self.safe_read_jp2_bytes(red_bytes)
202
- print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")
203
-
204
- # Read and resample green band
205
- green_data, green_height, green_width, green_profile = self.safe_read_jp2_bytes(green_bytes)
206
- green_data = self.safe_resample_data(green_data, green_height, green_width, target_height, target_width, green_profile)
207
- print(f"Green band after resampling: {green_data.shape}")
208
 
209
- # Read and resample NIR band
210
- nir_data, nir_height, nir_width, nir_profile = self.safe_read_jp2_bytes(nir_bytes)
211
- nir_data = self.safe_resample_data(nir_data, nir_height, nir_width, target_height, target_width, nir_profile)
212
- print(f"NIR band after resampling: {nir_data.shape}")
213
 
214
- # Verify all bands have the same shape
215
- if not (red_data.shape == green_data.shape == nir_data.shape):
216
- shapes = [red_data.shape, green_data.shape, nir_data.shape]
217
- error_msg = f"Band shape mismatch after resampling: {shapes}"
218
- error = pb_utils.TritonError(error_msg)
219
- response = pb_utils.InferenceResponse(output_tensors=[], error=error)
220
- responses.append(response)
221
- continue
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- # Check for valid dimensions
224
- if red_data.shape[0] == 0 or red_data.shape[1] == 0:
225
- error_msg = f"Invalid band dimensions: {red_data.shape}"
226
- error = pb_utils.TritonError(error_msg)
227
- response = pb_utils.InferenceResponse(output_tensors=[], error=error)
228
- responses.append(response)
229
- continue
230
 
231
- # Stack bands in CHW format for prediction (channels, height, width)
232
- prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
233
- print(f"Final prediction array shape: {prediction_array.shape}")
 
 
 
 
 
234
 
235
- # Run cloud detection prediction
236
- cloud_mask = predict_from_array(prediction_array)
237
- print(f"Cloud mask shape: {cloud_mask.shape}")
238
 
239
- # Flatten the mask for output
240
- if cloud_mask.ndim > 1:
241
- cloud_mask = cloud_mask.flatten()
242
 
243
- # Create output tensor (config expects TYPE_UINT8)
244
- output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
245
- response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
246
- responses.append(response)
 
 
247
 
248
  except Exception as e:
249
- # Enhanced error reporting
250
- error_msg = f"Error processing JP2 data: {str(e)}"
251
- print(f"Model execution error: {error_msg}")
252
- error = pb_utils.TritonError(error_msg)
253
  response = pb_utils.InferenceResponse(output_tensors=[], error=error)
254
- responses.append(response)
255
 
 
 
 
256
  return responses
257
 
258
  def finalize(self):
259
  """
260
- Clean up when the model is unloaded.
261
  """
262
- print('Cloud Detection model finalized')
 
4
  import rasterio
5
  from rasterio.io import MemoryFile
6
  from rasterio.enums import Resampling
 
 
 
7
 
8
  class TritonPythonModel:
9
  def initialize(self, args):
10
  """
11
  Initialize the model. This function is called once when the model is loaded.
12
  """
13
+ # You can load models or initialize resources here if needed.
14
+ # Ensure rasterio is installed in the Python backend environment.
15
+ print('Initialized Cloud Detection model with JP2 input')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def execute(self, requests):
18
  """
19
+ Process inference requests.
20
  """
21
  responses = []
22
+ # Every request must contain three JP2 byte strings (Red, Green, NIR).
23
  for request in requests:
24
+ # Get the input tensor containing the byte arrays
25
+ input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
26
+ # as_numpy() for TYPE_STRING gives an ndarray of Python bytes objects
27
+ jp2_bytes_list = input_tensor.as_numpy()
28
+
29
+ if len(jp2_bytes_list) != 3:
30
+ # Send an error response if the input shape is incorrect
31
+ error = pb_utils.TritonError(f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}")
32
+ response = pb_utils.InferenceResponse(output_tensors=[], error=error)
33
+ responses.append(response)
34
+ continue # Skip to the next request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Assume order: Red, Green, NIR based on client logic
37
+ red_bytes = jp2_bytes_list[0]
38
+ green_bytes = jp2_bytes_list[1]
39
+ nir_bytes = jp2_bytes_list[2]
40
 
41
+ try:
42
+ # Process JP2 bytes using rasterio in memory
43
+ with MemoryFile(red_bytes) as memfile_red:
44
+ with memfile_red.open() as src_red:
45
+ red_data = src_red.read(1).astype(np.float32)
46
+ target_height = src_red.height
47
+ target_width = src_red.width
48
+
49
+ with MemoryFile(green_bytes) as memfile_green:
50
+ with memfile_green.open() as src_green:
51
+ # Ensure green band matches red band dimensions (should if B03)
52
+ if src_green.height != target_height or src_green.width != target_width:
53
+ # Optional: Resample green if necessary, though B03 usually matches B04
54
+ green_data = src_green.read(
55
+ 1,
56
+ out_shape=(1, target_height, target_width),
57
+ resampling=Resampling.bilinear
58
+ ).astype(np.float32)
59
+ else:
60
+ green_data = src_green.read(1).astype(np.float32)
61
 
 
 
 
 
 
 
 
62
 
63
+ with MemoryFile(nir_bytes) as memfile_nir:
64
+ with memfile_nir.open() as src_nir:
65
+ # Resample NIR (B8A) to match Red/Green (B04/B03) resolution
66
+ nir_data = src_nir.read(
67
+ 1, # Read the first band
68
+ out_shape=(1, target_height, target_width),
69
+ resampling=Resampling.bilinear
70
+ ).astype(np.float32)
71
 
72
+ # Stack bands in CHW format (Red, Green, NIR) for the model
73
+ # Match the channel order expected by predict_from_array
74
+ input_array = np.stack([red_data, green_data, nir_data], axis=0)
75
 
76
+ # Perform inference using the original function
77
+ pred_mask = predict_from_array(input_array)
 
78
 
79
+ # Create output tensor
80
+ output_tensor = pb_utils.Tensor(
81
+ "output_mask",
82
+ pred_mask.astype(np.uint8)
83
+ )
84
+ response = pb_utils.InferenceResponse([output_tensor])
85
 
86
  except Exception as e:
87
+ # Handle errors during processing (e.g., invalid JP2 data)
88
+ error = pb_utils.TritonError(f"Error processing JP2 data: {str(e)}")
 
 
89
  response = pb_utils.InferenceResponse(output_tensors=[], error=error)
 
90
 
91
+ responses.append(response)
92
+
93
+ # Return a list of responses
94
  return responses
95
 
96
  def finalize(self):
97
  """
98
+ Called when the model is unloaded. Perform any necessary cleanup.
99
  """
100
+ print('Finalizing Cloud Detection model')