Amir Erfan Eshratifar committed on
Commit
241b6a2
·
1 Parent(s): 551ee08

model checkpoints, sample input, readme

README.md CHANGED
@@ -1,6 +1,35 @@
- ---
- license: mit
- ---
-
-
- Python 3.12
+ ---
+ license: mit
+ datasets:
+ - csaybar/CloudSEN12-high
+ language:
+ - en
+ base_model:
+ - NickWright/OmniCloudMask
+ tags:
+ - remote-sensing
+ - cloud-detection
+ ---
+
+
+ # Cloud Detection Model
+
+ This model is based on NickWright/OmniCloudMask for cloud detection in satellite imagery. It provides pixel-level segmentation with the following classes:
+
+ - 0 = Clear
+ - 1 = Thick Cloud
+ - 2 = Thin Cloud
+ - 3 = Cloud Shadow
+
+ ## Usage
+
+ The model requires Python 3.10 or higher. To use this model, install the dependencies and then run the inference script:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ ```bash
+ python3 model.py
+ ```
+
+ Below is a visualization of the cloud mask generated by the model:
+ ![Cloud Mask Visualization](cloud_mask_visualization.png)
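The jp2s/ folder added in this commit contains sample Sentinel-2 bands (B02, B03, B04, B8A). As a rough, unofficial sketch of what a script like model.py might do with them, the snippet below stacks the red, green and NIR bands and runs the bundled omnicloudmask predictor; the file paths and the assumption that the three bands share a common grid (B8A is natively 20 m) are illustrative, not part of the commit.

```python
# Hedged sketch: assumes the three bands have been resampled to one grid
# so they can be stacked into a single (3, height, width) array.
import numpy as np
import rasterio as rio

from omnicloudmask import predict_from_array

band_files = ["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B8A.jp2"]  # red, green, NIR

bands = []
for path in band_files:
    with rio.open(path) as src:
        bands.append(src.read(1))

array = np.stack(bands).astype(np.float32)  # (3, height, width)

# Returns a (1, height, width) mask: 0 clear, 1 thick cloud, 2 thin cloud, 3 cloud shadow.
mask = predict_from_array(array)
print(mask.shape, np.unique(mask))
```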
jp2s/B02.jp2 ADDED
jp2s/B03.jp2 ADDED
jp2s/B04.jp2 ADDED
jp2s/B8A.jp2 ADDED
omnicloudmask/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .__version__ import __version__
+ from .cloud_mask import predict_from_array, predict_from_load_func
+ from .data_loaders import (
+ load_ls8,
+ load_multiband,
+ load_s2,
+ )
+
+ __all__ = [
+ "predict_from_load_func",
+ "predict_from_array",
+ "load_ls8",
+ "load_multiband",
+ "load_s2",
+ "__version__",
+ ]
omnicloudmask/__version__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "1.0.9"
omnicloudmask/cloud_mask.py ADDED
@@ -0,0 +1,493 @@
1
+ import warnings
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from pathlib import Path
4
+ from threading import Thread
5
+ from typing import Callable, Generator, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from rasterio.profiles import Profile
10
+ from tqdm.auto import tqdm
11
+
12
+ from .__version__ import __version__
13
+ from .download_models import get_models
14
+ from .model_utils import (
15
+ create_gradient_mask,
16
+ default_device,
17
+ get_torch_dtype,
18
+ inference_and_store,
19
+ load_model_from_weights,
20
+ )
21
+ from .raster_utils import (
22
+ get_patch,
23
+ make_patch_indexes,
24
+ mask_prediction,
25
+ save_prediction,
26
+ )
27
+
28
+
29
+ def compile_batches(
30
+ batch_size: int,
31
+ patch_size: int,
32
+ patch_indexes: list[tuple[int, int, int, int]],
33
+ input_array: np.ndarray,
34
+ no_data_value: int,
35
+ inference_device: torch.device,
36
+ inference_dtype: torch.dtype,
37
+ ) -> Generator[tuple[torch.Tensor, list[tuple[int, int, int, int]]], None, None]:
38
+ """Used to compile batches of patches from the input array and return them as a generator."""
39
+
40
+ with ThreadPoolExecutor(max_workers=batch_size) as executor:
41
+ futures = [
42
+ executor.submit(get_patch, input_array, index, no_data_value)
43
+ for index in patch_indexes
44
+ ]
45
+
46
+ total_futures = len(futures)
47
+ all_indexes = set()
48
+ index_batch = []
49
+ patch_batch_array = np.zeros(
50
+ (batch_size, input_array.shape[0], patch_size, patch_size), dtype=np.float32
51
+ )
52
+
53
+ for index, future in enumerate(as_completed(futures)):
54
+ patch, new_index = future.result()
55
+
56
+ if patch is not None and new_index not in all_indexes:
57
+ index_batch.append(new_index)
58
+ patch_batch_array[len(index_batch) - 1] = patch
59
+ all_indexes.add(new_index)
60
+
61
+ if len(index_batch) == batch_size or index == total_futures - 1:
62
+ if len(index_batch) == 0:
63
+ continue
64
+ input_tensor = (
65
+ torch.tensor(patch_batch_array[: len(index_batch)])
66
+ .to(inference_device)
67
+ .to(inference_dtype)
68
+ )
69
+ yield input_tensor, index_batch
70
+ index_batch = []
71
+
72
+
73
+ def run_models_on_array(
74
+ models: list[torch.nn.Module],
75
+ input_array: np.ndarray,
76
+ pred_tracker: torch.Tensor,
77
+ grad_tracker: Union[torch.Tensor, None],
78
+ patch_size: int,
79
+ patch_overlap: int,
80
+ inference_device: torch.device,
81
+ batch_size: int = 2,
82
+ inference_dtype: torch.dtype = torch.float32,
83
+ no_data_value: int = 0,
84
+ ) -> None:
85
+ """Used to execute the model on the input array, in patches. Predictions are stored in pred_tracker and grad_tracker, updated in place."""
86
+ patch_indexes = make_patch_indexes(
87
+ array_height=input_array.shape[1],
88
+ array_width=input_array.shape[2],
89
+ patch_size=patch_size,
90
+ patch_overlap=patch_overlap,
91
+ )
92
+
93
+ gradient = create_gradient_mask(
94
+ patch_size, patch_overlap, device=inference_device, dtype=inference_dtype
95
+ )
96
+
97
+ input_tensor_gen = compile_batches(
98
+ batch_size=batch_size,
99
+ patch_size=patch_size,
100
+ patch_indexes=patch_indexes,
101
+ input_array=input_array,
102
+ no_data_value=no_data_value,
103
+ inference_device=inference_device,
104
+ inference_dtype=inference_dtype,
105
+ )
106
+
107
+ for patch_batch, index_batch in input_tensor_gen:
108
+ inference_and_store(
109
+ models=models,
110
+ patch_batch=patch_batch,
111
+ index_batch=index_batch,
112
+ pred_tracker=pred_tracker,
113
+ gradient=gradient,
114
+ grad_tracker=grad_tracker,
115
+ )
116
+
117
+
118
+ def check_patch_size(
119
+ input_array: np.ndarray, no_data_value: int, patch_size: int, patch_overlap: int
120
+ ) -> tuple[int, int]:
121
+ """Used to check the inputs and adjust the patch size and overlap if necessary."""
122
+ # check the shape of the input array
123
+ if len(input_array.shape) != 3:
124
+ raise ValueError(
125
+ f"Input array must have 3 dimensions, found {len(input_array.shape)}. The input should be in format (bands (red,green,NIR), height, width)."
126
+ )
127
+
128
+ # check the width and height are greater than 10 pixels
129
+ if min(input_array.shape[1], input_array.shape[2]) < 10:
130
+ raise ValueError(
131
+ f"Input array must have a width and height greater than 10 pixels, found shape {input_array.shape}. The input should be in format (bands (red,green,NIR), height, width)."
132
+ )
133
+ if min(input_array.shape[1], input_array.shape[2]) < 50:
134
+ warnings.warn(
135
+ f"Input width or height is less than 50 pixels, found shape {input_array.shape}. Such a small image may not provide adequate spatial context for the model."
136
+ )
137
+
138
+ # if the input has a lot of no data values and the patch size is larger than half the image size, we reduce the patch size and overlap
139
+ if np.count_nonzero(input_array == no_data_value) / input_array.size > 0.3:
140
+ if patch_size > min(input_array.shape[1], input_array.shape[2]) / 2:
141
+ patch_size = min(input_array.shape[1], input_array.shape[2]) // 2
142
+ if patch_size // 2 < patch_overlap:
143
+ patch_overlap = patch_size // 2
144
+
145
+ warnings.warn(
146
+ f"Significant no-data areas detected. Adjusting patch size to {patch_size}px and overlap to {patch_overlap}px to minimize no-data patches."
147
+ )
148
+
149
+ # if the patch size is larger than the image size, we reduce the patch size and overlap
150
+ if patch_size > min(input_array.shape[1], input_array.shape[2]):
151
+ patch_size = min(input_array.shape[1], input_array.shape[2])
152
+ if patch_size // 2 < patch_overlap:
153
+ patch_overlap = patch_size // 2
154
+ warnings.warn(
155
+ f"Patch size too large, reducing to {patch_size} and overlap to {patch_overlap}."
156
+ )
157
+
158
+ # if the patch overlap is larger than the patch size, raise an error
159
+ if patch_overlap >= patch_size:
160
+ raise ValueError(
161
+ f"Patch overlap {patch_overlap}px must be less than patch size {patch_size}px."
162
+ )
163
+ return patch_overlap, patch_size
164
+
165
+
166
+ def coordinator(
167
+ input_array: np.ndarray,
168
+ models: list[torch.nn.Module],
169
+ inference_dtype: torch.dtype,
170
+ export_confidence: bool,
171
+ softmax_output: bool,
172
+ inference_device: torch.device,
173
+ mosaic_device: torch.device,
174
+ patch_size: int,
175
+ patch_overlap: int,
176
+ batch_size: int,
177
+ profile: Profile = Profile(),
178
+ output_path: Path = Path(""),
179
+ no_data_value: int = 0,
180
+ pbar: Optional[tqdm] = None,
181
+ apply_no_data_mask: bool = False,
182
+ export_to_disk: bool = True,
183
+ save_executor: Optional[ThreadPoolExecutor] = None,
184
+ pred_classes: int = 4,
185
+ ) -> np.ndarray:
186
+ """Used to coordinate the process of predicting from an input array."""
187
+
188
+ patch_overlap, patch_size = check_patch_size(
189
+ input_array, no_data_value, patch_size, patch_overlap
190
+ )
191
+
192
+ pred_tracker = torch.zeros(
193
+ (pred_classes, *input_array.shape[1:3]),
194
+ dtype=inference_dtype,
195
+ device=mosaic_device,
196
+ )
197
+
198
+ grad_tracker = (
199
+ torch.zeros(input_array.shape[1:3], dtype=inference_dtype, device=mosaic_device)
200
+ if export_confidence
201
+ else None
202
+ )
203
+
204
+ run_models_on_array(
205
+ models=models,
206
+ input_array=input_array,
207
+ pred_tracker=pred_tracker,
208
+ grad_tracker=grad_tracker,
209
+ inference_device=inference_device,
210
+ inference_dtype=inference_dtype,
211
+ no_data_value=no_data_value,
212
+ patch_size=patch_size,
213
+ patch_overlap=patch_overlap,
214
+ batch_size=batch_size,
215
+ )
216
+
217
+ if export_confidence:
218
+ pred_tracker_norm = pred_tracker / grad_tracker
219
+ if softmax_output:
220
+ pred_tracker = torch.clip(
221
+ (torch.nn.functional.softmax(pred_tracker_norm, 0) + 0.001),
222
+ 0.001,
223
+ 0.999,
224
+ )
225
+ else:
226
+ pred_tracker = pred_tracker_norm
227
+
228
+ pred_tracker_np = pred_tracker.float().numpy(force=True)
229
+
230
+ else:
231
+ pred_tracker_np = (
232
+ torch.argmax(pred_tracker, 0, keepdim=True)
233
+ .numpy(force=True)
234
+ .astype(np.uint8)
235
+ )
236
+
237
+ if apply_no_data_mask:
238
+ pred_tracker_np = mask_prediction(input_array, pred_tracker_np, no_data_value)
239
+
240
+ if export_to_disk:
241
+ export_profile = profile.copy()
242
+ export_profile.update(
243
+ dtype=pred_tracker_np.dtype,
244
+ count=pred_tracker_np.shape[0],
245
+ compress="lzw",
246
+ nodata=0,
247
+ driver="GTiff",
248
+ )
249
+ # if an executor has been passed, submit the save_prediction function to it to avoid blocking the main thread
250
+ if save_executor:
251
+ save_executor.submit(
252
+ save_prediction, output_path, export_profile, pred_tracker_np
253
+ )
254
+ # otherwise save the prediction directly
255
+
256
+ else:
257
+ save_prediction(output_path, export_profile, pred_tracker_np)
258
+
259
+ if pbar:
260
+ pbar.update(1)
261
+ return pred_tracker_np
262
+
263
+
264
+ def collect_models(
265
+ custom_models: Union[list[torch.nn.Module], torch.nn.Module],
266
+ inference_device: torch.device,
267
+ inference_dtype: torch.dtype,
268
+ source: str,
269
+ destination_model_dir: Union[str, Path, None] = None,
270
+ ) -> list[torch.nn.Module]:
271
+ if not custom_models:
272
+ models = []
273
+ for model_details in get_models(model_dir=destination_model_dir, source=source):
274
+ models.append(
275
+ load_model_from_weights(
276
+ model_name=model_details["timm_model_name"],
277
+ weights_path=model_details["Path"],
278
+ device=inference_device,
279
+ dtype=inference_dtype,
280
+ )
281
+ )
282
+ else:
283
+ # if not a list, make it a list of models
284
+ if not isinstance(custom_models, list):
285
+ custom_models = [custom_models]
286
+
287
+ models = [
288
+ model.to(inference_dtype).to(inference_device) for model in custom_models
289
+ ]
290
+ return models
291
+
292
+
293
+ def predict_from_array(
294
+ input_array: np.ndarray,
295
+ patch_size: int = 1000,
296
+ patch_overlap: int = 300,
297
+ batch_size: int = 1,
298
+ inference_device: Union[str, torch.device] = default_device(),
299
+ mosaic_device: Optional[Union[str, torch.device]] = None,
300
+ inference_dtype: Union[torch.dtype, str] = torch.float32,
301
+ export_confidence: bool = False,
302
+ softmax_output: bool = True,
303
+ no_data_value: int = 0,
304
+ apply_no_data_mask: bool = True,
305
+ custom_models: Union[list[torch.nn.Module], torch.nn.Module] = [],
306
+ pred_classes: int = 4,
307
+ destination_model_dir: Union[str, Path, None] = None,
308
+ model_download_source: str = "google_drive",
309
+ ) -> np.ndarray:
310
+ """Predict a cloud and cloud shadow mask from a Red, Green and NIR numpy array, with a spatial res between 10 m and 50 m.
311
+
312
+ Args:
313
+ input_array (np.ndarray): A numpy array with shape (3, height, width) representing the Red, Green and NIR bands.
314
+ patch_size (int, optional): Size of the patches for inference. Defaults to 1000.
315
+ patch_overlap (int, optional): Overlap between patches for inference. Defaults to 300.
316
+ batch_size (int, optional): Number of patches to process in a batch. Defaults to 1.
317
+ inference_device (Union[str, torch.device], optional): Device to use for inference (e.g., 'cpu', 'cuda', 'mps'). Defaults to the device returned by default_device().
318
+ mosaic_device (Union[str, torch.device], optional): Device to use for mosaicking patches. Defaults to inference device.
319
+ inference_dtype (Union[torch.dtype, str], optional): Data type for inference. Defaults to torch.float32.
320
+ export_confidence (bool, optional): If True, exports confidence maps instead of predicted classes. Defaults to False.
321
+ softmax_output (bool, optional): If True, applies a softmax to the output, only used if export_confidence = True. Defaults to True.
322
+ no_data_value (int, optional): Value within input scenes that specifies no data region. Defaults to 0.
323
+ apply_no_data_mask (bool, optional): If True, applies a no-data mask to the predictions. Defaults to True.
324
+ custom_models (Union[list[torch.nn.Module], torch.nn.Module], optional): A list of custom torch models, or a single model, to use for prediction. Defaults to [].
325
+ pred_classes (int, optional): Number of classes to predict. Defaults to 4, to be used with custom models.
326
+ destination_model_dir (Union[str, Path, None], optional): Directory to save the model weights. Defaults to None.
327
+ model_download_source (str, optional): Source from which to download the model weights. Defaults to "google_drive", can also be "hugging_face".
328
+ Returns:
329
+ np.ndarray: A numpy array with shape (1, height, width), or (4, height, width) if export_confidence = True, representing the predicted cloud and cloud shadow mask.
330
+
331
+ """
332
+
333
+ inference_device = torch.device(inference_device)
334
+ if mosaic_device is None:
335
+ mosaic_device = inference_device
336
+ else:
337
+ mosaic_device = torch.device(mosaic_device)
338
+
339
+ inference_dtype = get_torch_dtype(inference_dtype)
340
+ # if no custom model paths are provided, use the default models
341
+ models = collect_models(
342
+ custom_models=custom_models,
343
+ inference_device=inference_device,
344
+ inference_dtype=inference_dtype,
345
+ source=model_download_source,
346
+ destination_model_dir=destination_model_dir,
347
+ )
348
+
349
+ pred_tracker = coordinator(
350
+ input_array=input_array,
351
+ models=models,
352
+ inference_device=inference_device,
353
+ mosaic_device=mosaic_device,
354
+ inference_dtype=inference_dtype,
355
+ export_confidence=export_confidence,
356
+ softmax_output=softmax_output,
357
+ patch_size=patch_size,
358
+ patch_overlap=patch_overlap,
359
+ batch_size=batch_size,
360
+ no_data_value=no_data_value,
361
+ export_to_disk=False,
362
+ apply_no_data_mask=apply_no_data_mask,
363
+ pred_classes=pred_classes,
364
+ )
365
+
366
+ return pred_tracker
367
+
368
+
369
+ def predict_from_load_func(
370
+ scene_paths: Union[list[Path], list[str]],
371
+ load_func: Callable,
372
+ patch_size: int = 1000,
373
+ patch_overlap: int = 300,
374
+ batch_size: int = 1,
375
+ inference_device: Union[str, torch.device] = default_device(),
376
+ mosaic_device: Optional[Union[str, torch.device]] = None,
377
+ inference_dtype: Union[torch.dtype, str] = torch.float32,
378
+ export_confidence: bool = False,
379
+ softmax_output: bool = True,
380
+ no_data_value: int = 0,
381
+ overwrite: bool = True,
382
+ apply_no_data_mask: bool = True,
383
+ output_dir: Optional[Union[Path, str]] = None,
384
+ custom_models: Union[list[torch.nn.Module], torch.nn.Module] = [],
385
+ destination_model_dir: Union[str, Path, None] = None,
386
+ model_download_source: str = "google_drive",
387
+ ) -> list[Path]:
388
+ """
389
+ Predicts cloud and cloud shadow masks for a list of scenes using a specified loading function.
390
+
391
+ Args:
392
+ scene_paths (Union[list[Path], list[str]]): A list of paths to the scene files to be processed.
393
+ load_func (Callable): A function to load the scene data. It should take an input_path parameter and return an (R, G, NIR) numpy array along with a rasterio profile for export; several load functions are provided in data_loaders.py.
394
+ patch_size (int, optional): Size of the patches for inference. Defaults to 1000.
395
+ patch_overlap (int, optional): Overlap between patches for inference. Defaults to 300.
396
+ batch_size (int, optional): Number of patches to process in a batch. Defaults to 1.
397
+ inference_device (Union[str, torch.device], optional): Device to use for inference (e.g., 'cpu', 'cuda', 'mps'). Defaults to the device returned by default_device().
398
+ mosaic_device (Union[str, torch.device], optional): Device to use for mosaicking patches. Defaults to inference device.
399
+ inference_dtype (Union[torch.dtype, str], optional): Data type for inference. Defaults to torch.float32.
400
+ export_confidence (bool, optional): If True, exports confidence maps instead of predicted classes. Defaults to False.
401
+ softmax_output (bool, optional): If True, applies a softmax to the output, only used if export_confidence = True. Defaults to True.
402
+ no_data_value (int, optional): Value within input scenes that specifies no data region. Defaults to 0.
403
+ overwrite (bool, optional): If False, skips scenes that already have a prediction file. Defaults to True.
404
+ apply_no_data_mask (bool, optional): If True, applies a no-data mask to the predictions. Defaults to True.
405
+ output_dir (Optional[Union[Path, str]], optional): Directory to save the prediction files. Defaults to None. If None, the predictions will be saved in the same directory as the input scene.
406
+ custom_models (Union[list[torch.nn.Module], torch.nn.Module], optional): A list of custom torch models, or a single model, to use for prediction. Defaults to [].
407
+ destination_model_dir (Union[str, Path, None], optional): Directory to save the model weights. Defaults to None.
408
+ model_download_source (str, optional): Source from which to download the model weights. Defaults to "google_drive", can also be "hugging_face".
409
+
410
+ Returns:
411
+ list[Path]: A list of paths to the output prediction files.
412
+
413
+ """
414
+ pred_paths = []
415
+ inf_thread = Thread()
416
+ save_executor = ThreadPoolExecutor(max_workers=1)
417
+
418
+ inference_device = torch.device(inference_device)
419
+ if mosaic_device is None:
420
+ mosaic_device = inference_device
421
+ else:
422
+ mosaic_device = torch.device(mosaic_device)
423
+
424
+ inference_dtype = get_torch_dtype(inference_dtype)
425
+
426
+ models = collect_models(
427
+ custom_models=custom_models,
428
+ inference_device=inference_device,
429
+ inference_dtype=inference_dtype,
430
+ destination_model_dir=destination_model_dir,
431
+ source=model_download_source,
432
+ )
433
+
434
+ pbar = tqdm(
435
+ total=len(scene_paths),
436
+ desc=f"Running inference using {inference_device.type} {str(inference_dtype).split('.')[-1]}",
437
+ )
438
+
439
+ for scene_path in scene_paths:
440
+ scene_path = Path(scene_path)
441
+ file_name = f"{scene_path.stem}_OCM_v{__version__.replace('.','_')}.tif"
442
+
443
+ if output_dir is None:
444
+ output_path = scene_path.parent / file_name
445
+ else:
446
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
447
+ output_path = Path(output_dir) / file_name
448
+
449
+ pred_paths.append(output_path)
450
+
451
+ if output_path.exists() and not overwrite:
452
+ pbar.update(1)
453
+ pbar.refresh()
454
+ continue
455
+
456
+ input_array, profile = load_func(input_path=scene_path)
457
+
458
+ while inf_thread.is_alive():
459
+ inf_thread.join()
460
+
461
+ inf_thread = Thread(
462
+ target=coordinator,
463
+ kwargs={
464
+ "input_array": input_array,
465
+ "profile": profile,
466
+ "output_path": output_path,
467
+ "models": models,
468
+ "inference_dtype": inference_dtype,
469
+ "export_confidence": export_confidence,
470
+ "softmax_output": softmax_output,
471
+ "inference_device": inference_device,
472
+ "mosaic_device": mosaic_device,
473
+ "patch_size": patch_size,
474
+ "patch_overlap": patch_overlap,
475
+ "batch_size": batch_size,
476
+ "no_data_value": no_data_value,
477
+ "pbar": pbar,
478
+ "apply_no_data_mask": apply_no_data_mask,
479
+ "save_executor": save_executor,
480
+ },
481
+ )
482
+ inf_thread.start()
483
+
484
+ while inf_thread.is_alive():
485
+ inf_thread.join()
486
+
487
+ if inference_device.type.startswith("cuda"):
488
+ torch.cuda.empty_cache()
489
+
490
+ save_executor.shutdown(wait=True)
491
+ pbar.refresh()
492
+
493
+ return pred_paths
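As a usage note (not part of the committed file), the sketch below shows how predict_from_load_func defined above might be combined with the bundled load_s2 loader; the .SAFE folder path is a placeholder.

```python
from pathlib import Path

from omnicloudmask import load_s2, predict_from_load_func

# Placeholder scene path: any Sentinel-2 L1C/L2A .SAFE folder.
scenes = [Path("/data/S2A_MSIL2A_example.SAFE")]

# Each prediction is written next to its scene as <stem>_OCM_v1_0_9.tif
# (or into output_dir when one is supplied).
pred_paths = predict_from_load_func(
    scene_paths=scenes,
    load_func=load_s2,
    batch_size=1,
    export_confidence=False,
)
print(pred_paths)
```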
omnicloudmask/data_loaders.py ADDED
@@ -0,0 +1,159 @@
1
+ import warnings
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import rasterio as rio
7
+ from rasterio.profiles import Profile
8
+
9
+
10
+ def load_s2(
11
+ input_path: Union[Path, str],
12
+ resolution: float = 10.0,
13
+ required_bands: list[str] = ["B04", "B03", "B8A"],
14
+ ) -> tuple[np.ndarray, Profile]:
15
+ """Load a Sentinel-2 (L1C or L2A) image from a SAFE folder containing the bands"""
16
+ if not 10 <= resolution <= 50:
17
+ raise ValueError("Resolution must be between 10 and 50")
18
+ input_path = Path(input_path)
19
+ processing_level = find_s2_processing_level(input_path)
20
+ return open_s2_bands(input_path, processing_level, resolution, required_bands)
21
+
22
+
23
+ def find_s2_processing_level(
24
+ input_path: Path,
25
+ ) -> str:
26
+ """Derive the processing level of a Sentinel-2 image from the folder name."""
27
+
28
+ folder_name = Path(input_path).name
29
+ processing_level = folder_name.split("_")[1][3:6]
30
+
31
+ if processing_level not in ["L1C", "L2A"]:
32
+ raise ValueError(
33
+ f"Processing level {processing_level} not recognized, expected L1C or L2A"
34
+ )
35
+ return processing_level
36
+
37
+
38
+ def open_s2_bands(
39
+ input_path: Path,
40
+ processing_level: str,
41
+ resolution: float,
42
+ required_bands: list[str],
43
+ ) -> tuple[np.ndarray, Profile]:
44
+ bands = []
45
+ for band_name in required_bands:
46
+ if processing_level == "L1C":
47
+ try:
48
+ band = list(input_path.rglob(f"*IMG_DATA/*{band_name}.jp2"))[0]
49
+
50
+ except IndexError:
51
+ raise ValueError(f"Band {band_name} not found in {input_path}")
52
+ else:
53
+ band = None
54
+ for search_resolution in [10, 20, 60]:
55
+ band_paths = list(
56
+ input_path.rglob(f"*{band_name}_{search_resolution}m.jp2")
57
+ )
58
+ if band_paths:
59
+ band = band_paths[0]
60
+ break
61
+ if not band:
62
+ raise ValueError(f"Band {band_name} not found in {input_path}")
63
+
64
+ with rio.open(band) as src:
65
+ profile = src.profile
66
+ native_resolution = int(src.res[0])
67
+ scale_factor = native_resolution / resolution
68
+ if native_resolution == resolution:
69
+ bands.append(src.read(1))
70
+ else:
71
+ bands.append(
72
+ src.read(
73
+ 1,
74
+ out_shape=(
75
+ int(src.height * scale_factor),
76
+ int(src.width * scale_factor),
77
+ ),
78
+ )
79
+ )
80
+ profile["transform"] = rio.transform.from_origin( # type: ignore
81
+ profile["transform"][2],
82
+ profile["transform"][5],
83
+ resolution,
84
+ resolution,
85
+ )
86
+ data = np.array(bands)
87
+ profile["height"] = data.shape[1]
88
+ profile["width"] = data.shape[2]
89
+ return data, profile
90
+
91
+
92
+ def load_multiband(
93
+ input_path: Union[Path, str],
94
+ resample_res: Optional[float] = None,
95
+ band_order: Optional[list[int]] = None,
96
+ ) -> tuple[np.ndarray, Profile]:
97
+ """Load a multiband image and resample it to requested resolution."""
98
+ if band_order is None:
99
+ warnings.warn(
100
+ "No band order provided, using default [1, 2, 3] (RGN)", UserWarning
101
+ )
102
+ band_order = [1, 2, 3]
103
+ input_path = Path(input_path)
104
+
105
+ with rio.open(input_path) as src:
106
+ if resample_res:
107
+ current_res = src.res
108
+ desired_res = (resample_res, resample_res)
109
+ scale_factor = (
110
+ current_res[0] / desired_res[0],
111
+ current_res[1] / desired_res[1],
112
+ )
113
+ else:
114
+ scale_factor = (1, 1)
115
+
116
+ data = src.read(
117
+ band_order,
118
+ out_shape=(
119
+ len(band_order),
120
+ int(src.height * scale_factor[0]),
121
+ int(src.width * scale_factor[1]),
122
+ ),
123
+ resampling=rio.enums.Resampling.nearest, # type: ignore
124
+ )
125
+ profile = src.profile
126
+
127
+ return data, profile
128
+
129
+
130
+ def load_ls8(
131
+ input_path: Union[Path, str],
132
+ resolution: int = 30,
133
+ required_bands=["B4", "B3", "B5"],
134
+ ) -> tuple[np.ndarray, Profile]:
135
+ """Load a Landsat 8 image from a folder containing the bands"""
136
+ if resolution != 30:
137
+ raise ValueError("Resolution must be 30")
138
+
139
+ input_path = Path(input_path)
140
+
141
+ band_files = {}
142
+ for band_name in required_bands:
143
+ try:
144
+ band = list(input_path.rglob(f"*{band_name}.TIF"))[0]
145
+
146
+ except IndexError:
147
+ raise ValueError(f"Band {band_name} not found in {input_path}")
148
+ band_files[band_name] = band
149
+
150
+ data = []
151
+ profile = Profile()
152
+ for band_name in required_bands:
153
+ with rio.open(band_files[band_name]) as src:
154
+ if not profile:
155
+ profile = src.profile
156
+ data.append(src.read(1))
157
+
158
+ data = np.array(data)
159
+ return data, profile
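A short, hedged example of the loaders above (the GeoTIFF path and band indexes are placeholders): load_multiband reads an arbitrary raster with an explicit red/green/NIR band order, and the resulting array can be fed straight to predict_from_array.

```python
from omnicloudmask import load_multiband, predict_from_array

# Placeholder: a multiband GeoTIFF whose bands 3, 2 and 4 hold red, green and NIR.
array, profile = load_multiband(
    "scene.tif",
    resample_res=20.0,      # optional resampling target in the raster's CRS units
    band_order=[3, 2, 4],   # 1-based rasterio band indexes (red, green, NIR)
)

mask = predict_from_array(array)
print(mask.shape)
```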
omnicloudmask/download_models.py ADDED
@@ -0,0 +1,92 @@
1
+ from pathlib import Path
2
+ from typing import Union
3
+
4
+ import gdown
5
+ import pandas as pd
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+ from safetensors.torch import load_file
9
+
10
+
11
+ def download_file_from_google_drive(file_id: str, destination: Path) -> None:
12
+ """
13
+ Downloads a file from Google Drive and saves it at the given destination using gdown.
14
+
15
+ Args:
16
+ file_id (str): The ID of the file on Google Drive.
17
+ destination (Path): The local path where the file should be saved.
18
+ """
19
+ url = f"https://drive.google.com/uc?id={file_id}"
20
+ gdown.download(url, str(destination), quiet=False)
21
+
22
+
23
+ def download_file_from_hugging_face(destination: Path) -> None:
24
+ """
25
+ Downloads a file from Hugging Face and saves it at the given destination using hf_hub_download.
26
+ Loads the resulting safetensors file and saves it as a PyTorch model state for compatibility with the rest of the codebase.
27
+
28
+ Args:
29
+ file_id (str): The ID of the file on Hugging Face.
30
+ destination (Path): The local path where the file should be saved.
31
+ """
32
+ file_name = destination.stem
33
+ safetensor_path = hf_hub_download(
34
+ repo_id="NickWright/OmniCloudMask",
35
+ filename=f"{file_name}.safetensors",
36
+ force_download=True,
37
+ cache_dir=destination.parent,
38
+ )
39
+ model_state = load_file(safetensor_path)
40
+ torch.save(model_state, destination)
41
+
42
+
43
+ def download_file(file_id: str, destination: Path, source: str) -> None:
44
+ if source == "google_drive":
45
+ download_file_from_google_drive(file_id, destination)
46
+ elif source == "hugging_face":
47
+ download_file_from_hugging_face(destination)
48
+ else:
49
+ raise ValueError(
50
+ "Invalid source. Supported sources are 'google_drive' and 'hugging_face'."
51
+ )
52
+
53
+
54
+ def get_models(
55
+ force_download: bool = False,
56
+ model_dir: Union[str, Path, None] = None,
57
+ source: str = "google_drive",
58
+ ) -> list[dict]:
59
+ """
60
+ Downloads the model weights from Google Drive and saves them locally.
61
+
62
+ Args:
63
+ force_download (bool): Whether to force download the model weights even if they already exist locally.
64
+ model_dir (Union[str, Path, None]): The directory where the model weights should be saved.
65
+ source (str): The source from which the model weights should be downloaded. Currently, only "google_drive" or "hugging_face" are supported.
66
+ """
67
+
68
+ df = pd.read_csv(
69
+ Path(__file__).resolve().parent / "models/model_download_links.csv"
70
+ )
71
+ model_paths = []
72
+
73
+ for _, row in df.iterrows():
74
+ file_id = str(row["google_drive_id"])
75
+
76
+ if model_dir is not None:
77
+ model_dir = Path(model_dir)
78
+ else:
79
+ model_dir = Path(__file__).resolve().parent / "models"
80
+
81
+ model_dir.mkdir(exist_ok=True)
82
+ destination = model_dir / str(row["file_name"])
83
+ timm_model_name = row["timm_model_name"]
84
+
85
+ if not destination.exists() or force_download:
86
+ download_file(file_id=file_id, destination=destination, source=source)
87
+
88
+ elif destination.stat().st_size <= 1024 * 1024:
89
+ download_file(file_id=file_id, destination=destination, source=source)
90
+
91
+ model_paths.append({"Path": destination, "timm_model_name": timm_model_name})
92
+ return model_paths
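For reference (not part of the committed file), the weights can be pre-fetched into a chosen directory; the directory path below is a placeholder. With source="hugging_face", the safetensors files are pulled from the NickWright/OmniCloudMask repo and re-saved as .pth state dicts, mirroring the checkpoints bundled under omnicloudmask/models/ in this commit.

```python
from pathlib import Path

from omnicloudmask.download_models import get_models

# Placeholder destination directory for the two default checkpoints.
weights_dir = Path("./ocm_weights")

models = get_models(model_dir=weights_dir, source="hugging_face")
for entry in models:
    print(entry["timm_model_name"], "->", entry["Path"])
```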
omnicloudmask/model_utils.py ADDED
@@ -0,0 +1,208 @@
1
+ from functools import partial
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import timm
7
+ import torch
8
+ from fastai.vision.learner import create_unet_model
9
+
10
+
11
+ def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
12
+ """Return a torch.dtype from a string or torch.dtype."""
13
+ if isinstance(dtype, str):
14
+ dtype_mapping = {
15
+ "float16": torch.float16,
16
+ "half": torch.float16,
17
+ "fp16": torch.float16,
18
+ "float32": torch.float32,
19
+ "float": torch.float32,
20
+ "bfloat16": torch.bfloat16,
21
+ "bf16": torch.bfloat16,
22
+ }
23
+ try:
24
+ return dtype_mapping[dtype.lower()]
25
+ except KeyError:
26
+ raise ValueError(
27
+ f"Invalid dtype: {dtype}. Must be one of {list(dtype_mapping.keys())}"
28
+ )
29
+ elif isinstance(dtype, torch.dtype):
30
+ return dtype
31
+ else:
32
+ raise TypeError(
33
+ f"Expected dtype to be a str or torch.dtype, but got {type(dtype)}"
34
+ )
35
+
36
+
37
+ def create_gradient_mask(
38
+ patch_size: int, patch_overlap: int, device: torch.device, dtype: torch.dtype
39
+ ) -> torch.Tensor:
40
+ """Create a gradient mask for a given patch size and overlap."""
41
+ if patch_overlap > 0:
42
+ if patch_overlap * 2 > patch_size:
43
+ patch_overlap = patch_size // 2
44
+
45
+ gradient_strength = 1
46
+ gradient = (
47
+ torch.ones((patch_size, patch_size), dtype=torch.int, device=device)
48
+ * patch_overlap
49
+ )
50
+ gradient[:, :patch_overlap] = torch.tile(
51
+ torch.arange(1, patch_overlap + 1),
52
+ (patch_size, 1),
53
+ )
54
+ gradient[:, -patch_overlap:] = torch.tile(
55
+ torch.arange(patch_overlap, 0, -1),
56
+ (patch_size, 1),
57
+ )
58
+ gradient = gradient / patch_overlap
59
+ rotated_gradient = torch.rot90(gradient)
60
+ combined_gradient = rotated_gradient * gradient
61
+
62
+ combined_gradient = (combined_gradient * gradient_strength) + (
63
+ 1 - gradient_strength
64
+ )
65
+ else:
66
+ combined_gradient = torch.ones(
67
+ (patch_size, patch_size), dtype=torch.int, device=device
68
+ )
69
+ return combined_gradient.to(dtype)
70
+
71
+
72
+ def channel_norm(patch: np.ndarray, nodata_value: Optional[int] = 0) -> np.ndarray:
73
+ """Normalize each band of the input array by subtracting the nonzero mean and dividing
74
+ by the nonzero standard deviation then fill nodata values with 0."""
75
+ out_array = np.zeros(patch.shape).astype(np.float32)
76
+ for id, band in enumerate(patch):
77
+ # Mask for non-zero values
78
+ mask = band != nodata_value
79
+ # Check if there are any non-zero values
80
+ if np.any(mask):
81
+ mean = band[mask].mean()
82
+ std = band[mask].std()
83
+ if std == 0:
84
+ std = 1 # Prevent division by zero
85
+ # Normalize only non-zero values
86
+ out_array[id][mask] = (band[mask] - mean) / std
87
+ else:
88
+ continue
89
+ # Fill original nodata values with 0
90
+ out_array[id][~mask] = 0
91
+ return out_array
92
+
93
+
94
+ def store_results(
95
+ pred_batch: torch.Tensor,
96
+ index_batch: list[tuple],
97
+ pred_tracker: torch.Tensor,
98
+ gradient: torch.Tensor,
99
+ grad_tracker: Optional[torch.Tensor] = None,
100
+ ) -> None:
101
+ """Store the results of the model inference in the pred_tracker and grad_tracker tensors."""
102
+ # Store the predictions in the pred_tracker tensor
103
+ assert pred_batch.ndim == 4, "pred_batch must have 4 dimensions, (B, class, H, W)"
104
+ assert pred_batch.shape[0] == len(index_batch), "Batch size must match index_batch"
105
+ assert pred_batch.shape[1] == pred_tracker.shape[0], "Number of classes must match"
106
+ assert pred_batch.shape[2] == gradient.shape[0], "Height must match gradient"
107
+ assert pred_batch.shape[3] == gradient.shape[1], "Width must match gradient"
108
+
109
+ pred_batch *= gradient[None, None, :, :]
110
+
111
+ for pred, index in zip(pred_batch.to(pred_tracker.device), index_batch):
112
+ pred_tracker[:, index[0] : index[1], index[2] : index[3]] += pred
113
+ if grad_tracker is not None:
114
+ grad_tracker[index[0] : index[1], index[2] : index[3]] += gradient.to(
115
+ grad_tracker.device
116
+ )
117
+
118
+
119
+ def inference_and_store(
120
+ models: list[torch.nn.Module],
121
+ patch_batch: torch.Tensor,
122
+ index_batch: list[tuple],
123
+ pred_tracker: torch.Tensor,
124
+ gradient: torch.Tensor,
125
+ grad_tracker: Optional[torch.Tensor] = None,
126
+ ) -> None:
127
+ """Perform inference on the patch_batch and store the results in the pred_tracker and grad_tracker tensors."""
128
+ # pre-initialize the all_preds tensor to store the predictions from each model
129
+ all_preds = torch.zeros(
130
+ len(models),
131
+ patch_batch.shape[0],
132
+ pred_tracker.shape[0],
133
+ patch_batch.shape[2],
134
+ patch_batch.shape[3],
135
+ device=patch_batch.device,
136
+ dtype=patch_batch.dtype,
137
+ )
138
+ for index, model in enumerate(models):
139
+ with torch.no_grad():
140
+ all_preds[index] = model(patch_batch)
141
+
142
+ mean_preds = all_preds.mean(dim=0)
143
+
144
+ store_results(
145
+ pred_batch=mean_preds,
146
+ index_batch=index_batch,
147
+ pred_tracker=pred_tracker,
148
+ gradient=gradient,
149
+ grad_tracker=grad_tracker,
150
+ )
151
+
152
+
153
+ def default_device() -> torch.device:
154
+ """Return the default device for model inference"""
155
+ if torch.cuda.is_available():
156
+ return torch.device("cuda")
157
+ elif torch.backends.mps.is_available():
158
+ return torch.device("mps")
159
+ return torch.device("cpu")
160
+
161
+
162
+ def load_model(
163
+ model_path: Union[Path, str],
164
+ device: torch.device,
165
+ dtype: torch.dtype = torch.float32,
166
+ ) -> torch.nn.Module:
167
+ """Load a PyTorch model from a file and move it to the specified device and dtype."""
168
+ model_path = Path(model_path)
169
+ if not model_path.is_file():
170
+ raise FileNotFoundError(f"Model file not found at: {model_path}")
171
+
172
+ try:
173
+ model = torch.load(model_path, map_location="cpu")
174
+ except Exception as e:
175
+ raise RuntimeError(f"Error loading model: {e}")
176
+
177
+ model.eval()
178
+ return model.to(dtype).to(device)
179
+
180
+
181
+ def load_model_from_weights(
182
+ model_name: str,
183
+ weights_path: Union[Path, str],
184
+ device: torch.device,
185
+ dtype: torch.dtype = torch.float32,
186
+ in_chans: int = 3,
187
+ n_out: int = 4,
188
+ ) -> torch.nn.Module:
189
+ """Build Fastai DynamicUnet model from timm model and load weights from file"""
190
+ timm_model = partial(
191
+ timm.create_model,
192
+ model_name=model_name,
193
+ pretrained=False,
194
+ in_chans=in_chans,
195
+ )
196
+
197
+ model = create_unet_model(
198
+ arch=timm_model,
199
+ n_out=n_out,
200
+ img_size=(509, 509),
201
+ act_cls=torch.nn.Mish,
202
+ pretrained=False,
203
+ )
204
+
205
+ model.load_state_dict(torch.load(weights_path, weights_only=True))
206
+ model.eval()
207
+
208
+ return model.to(dtype).to(device)
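A small, hedged demonstration of two helpers above: channel_norm normalises only the valid pixels of a toy patch, and create_gradient_mask produces the blending weights used when overlapping patches are mosaicked. The toy values are illustrative only.

```python
import numpy as np
import torch

from omnicloudmask.model_utils import channel_norm, create_gradient_mask

# Toy single-band patch with a nodata (0) border: statistics come from the
# valid pixels only, and the nodata region is written back as 0.
patch = np.array([[[0, 0, 0],
                   [0, 5, 7],
                   [0, 9, 3]]], dtype=np.float32)
print(channel_norm(patch, nodata_value=0)[0])

# Patch-blending weights: close to 1 in the interior, tapering inside the overlap.
weights = create_gradient_mask(
    patch_size=8, patch_overlap=2, device=torch.device("cpu"), dtype=torch.float32
)
print(weights)
```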
omnicloudmask/models/PM_model_2.2.10_RG_NIR_509_convnextv2_nano.fcmae_ft_in1k_PT_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d83ddef55797fd443cb30fdd545edf4f070c76cfb031ab53af2cd01f51d6d0f
+ size 130226202
omnicloudmask/models/PM_model_2.2.10_RG_NIR_509_regnety_004.pycls_in1k_PT_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be9e29fa69464a286d40e71b5af10894cabe2a258a6f1de4e869500ae704c7bd
+ size 72458313
omnicloudmask/models/model_download_links.csv ADDED
@@ -0,0 +1,3 @@
+ file_name,timm_model_name,google_drive_id
+ PM_model_2.2.10_RG_NIR_509_regnety_004.pycls_in1k_PT_state.pth,regnety_004,1tGJh9nnrH-apjmV70AcK8VtXnbBtRb67
+ PM_model_2.2.10_RG_NIR_509_convnextv2_nano.fcmae_ft_in1k_PT_state.pth,convnextv2_nano,1QXQ_oPhLKEowC9fxlZGLOACt8gCNMbWP
omnicloudmask/raster_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import rasterio as rio
6
+ from rasterio.profiles import Profile
7
+
8
+ from .model_utils import channel_norm
9
+
10
+
11
+ def get_patch(
12
+ input_array: np.ndarray,
13
+ index: tuple,
14
+ no_data_value: Optional[int] = 0,
15
+ ) -> tuple[Optional[np.ndarray], Optional[tuple[int, int, int, int]]]:
16
+ """Extract a patch from a 3D array and normalize it. If the patch is entirely nodata, return None.
17
+ If the patch contains nodata, try to move patches to reduce nodata regions in patches.
18
+ """
19
+ assert input_array.ndim == 3, "Input array must have 3 dimensions"
20
+
21
+ top, bottom, left, right = index
22
+ patch = input_array[:, top:bottom, left:right].astype(np.float32)
23
+
24
+ if patch.sum() == 0:
25
+ return None, None
26
+
27
+ if no_data_value is None:
28
+ if np.all(patch == no_data_value):
29
+ return None, None
30
+
31
+ if np.any(patch == 0):
32
+ max_bottom, max_right = input_array.shape[1:3]
33
+
34
+ if np.any(patch[:, 0, :]) or np.any(patch[:, -1, :]):
35
+ while not np.any(patch[:, 0, :]) and bottom < max_bottom: # check top row
36
+ patch = patch[:, 1:, :]
37
+ top += 1
38
+ bottom += 1
39
+
40
+ while not np.any(patch[:, -1, :]) and top > 0:
41
+ patch = patch[:, :-1, :]
42
+ bottom -= 1
43
+ top -= 1
44
+
45
+ # Both sides are not zero-filled
46
+ if np.any(patch[:, :, 0]) or np.any(patch[:, :, -1]):
47
+ while not np.any(patch[:, :, 0]) and right < max_right: # check left column
48
+ patch = patch[:, :, 1:]
49
+ left += 1
50
+ right += 1
51
+
52
+ while not np.any(patch[:, :, -1]) and left > 0: # check right column
53
+ patch = patch[:, :, :-1]
54
+ right -= 1
55
+ left -= 1
56
+ patch = input_array[:, top:bottom, left:right].astype(np.float32)
57
+ index = (top, bottom, left, right)
58
+
59
+ # trim index bottom and right to match patch shape
60
+ index = (top, top + patch.shape[1], left, left + patch.shape[2])
61
+ return channel_norm(patch, no_data_value), index
62
+
63
+
64
+ def mask_prediction(
65
+ scene: np.ndarray, pred_tracker_np: np.ndarray, no_data_value: int = 0
66
+ ) -> np.ndarray:
67
+ """Create a no data mask from a raster scene."""
68
+ assert scene.ndim == 3, "Scene must have 3 dimensions"
69
+ assert pred_tracker_np.ndim == 3, "Prediction tracker must have 3 dimensions"
70
+ assert (
71
+ scene.shape[1:] == pred_tracker_np.shape[1:]
72
+ ), "Scene and prediction tracker must have the same shape"
73
+ mask = np.all(scene != no_data_value, axis=0).astype(np.uint8)
74
+ pred_tracker_np *= mask
75
+ return pred_tracker_np
76
+
77
+
78
+ def make_patch_indexes(
79
+ array_width: int,
80
+ array_height: int,
81
+ patch_size: int = 1000,
82
+ patch_overlap: int = 300,
83
+ ) -> list[tuple[int, int, int, int]]:
84
+ """Create a list of patch indexes for a given shape and patch size."""
85
+ assert patch_size > patch_overlap, "Patch size must be greater than patch overlap"
86
+ assert patch_overlap >= 0, "Patch overlap must be greater than or equal to 0"
87
+ assert patch_size > 0, "Patch size must be greater than 0"
88
+ assert (
89
+ patch_size <= array_width
90
+ ), "Patch size must be less than or equal to array width"
91
+ assert (
92
+ patch_size <= array_height
93
+ ), "Patch size must be less than or equal to array height"
94
+
95
+ stride = patch_size - patch_overlap
96
+
97
+ max_bottom = array_height - patch_size
98
+ max_right = array_width - patch_size
99
+
100
+ patch_indexes = []
101
+ for top in range(0, array_height, stride):
102
+ if top > max_bottom:
103
+ top = max_bottom
104
+ bottom = top + patch_size
105
+ for left in range(0, array_width, stride):
106
+ if left > max_right:
107
+ left = max_right
108
+ right = left + patch_size
109
+ patch_indexes.append((top, bottom, left, right))
110
+
111
+ return patch_indexes
112
+
113
+
114
+ def save_prediction(
115
+ output_path: Path, export_profile: Profile, pred_tracker_np: np.ndarray
116
+ ) -> None:
117
+ with rio.open(output_path, "w", **export_profile) as dst:
118
+ dst.write(pred_tracker_np)
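To round off, a hedged example of the tiling helper defined above; the scene dimensions are made up. make_patch_indexes returns (top, bottom, left, right) windows that cover the array with the requested overlap, clamped at the bottom and right edges.

```python
from omnicloudmask.raster_utils import make_patch_indexes

# Hypothetical 1800 x 2500 (height x width) scene, 1000 px patches, 300 px overlap.
windows = make_patch_indexes(
    array_width=2500, array_height=1800, patch_size=1000, patch_overlap=300
)

print(len(windows))              # number of patches needed to cover the scene
print(windows[0], windows[-1])   # first and last (top, bottom, left, right) windows
```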