examples feature
app.py CHANGED
@@ -89,12 +89,12 @@ def download_model_weights():
     try:
         os.makedirs('weights', exist_ok=True)
         local_path = os.path.join('weights', SEGMENTATION_MODEL_FILENAME)
-
+
         # Check if weights are already downloaded
         if os.path.exists(local_path):
             print(f"Model weights already downloaded at {local_path}")
             return local_path
-
+
         # Download weights
         print(f"Downloading model weights from {SEGMENTATION_MODEL_REPO}...")
         url = f"https://huggingface.co/{SEGMENTATION_MODEL_REPO}/resolve/main/{SEGMENTATION_MODEL_FILENAME}"
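For context, the manual URL construction above could also be replaced by huggingface_hub, which handles caching; a minimal sketch (hypothetical alternative, not part of this commit):

# Hypothetical variant using huggingface_hub instead of a raw resolve URL.
from huggingface_hub import hf_hub_download

def download_model_weights_cached():
    # Downloads the file on first call, then reuses the local HF cache.
    return hf_hub_download(repo_id=SEGMENTATION_MODEL_REPO,
                           filename=SEGMENTATION_MODEL_FILENAME)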
@@ -159,14 +159,14 @@ def normalize(band):
     band_cleaned = band[np.isfinite(band)]
     if len(band_cleaned) == 0:
         return band
-
+
     # Use percentiles to avoid outliers
     band_min, band_max = np.percentile(band_cleaned, (2, 98))
-
+
     # Avoid division by zero
     if band_max == band_min:
         return np.zeros_like(band)
-
+
     band_normalized = (band - band_min) / (band_max - band_min)
     band_normalized = np.clip(band_normalized, 0, 1)
     return band_normalized
@@ -175,19 +175,19 @@ def calculate_cv(band):
     """Calculate coefficient of variation (CV) for a band"""
     # First normalize the band
     band_normalized = normalize(band)
-
+
     # Handle potential NaN or inf values
     band_cleaned = band_normalized[np.isfinite(band_normalized)]
     if len(band_cleaned) == 0:
         return 0
-
+
     # Get mean and std dev
     mean = np.mean(band_cleaned)
-
+
     # Guard against division by zero or very small means
     if abs(mean) < 1e-10:
         return 0
-
+
     std = np.std(band_cleaned)
     cv = (std / mean) # CV as ratio (not percentage)
     return cv
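normalize() and calculate_cv() together reduce a raw band to a single texture statistic. A self-contained sketch of the same computation on synthetic data (all names here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
band = rng.normal(loc=0.5, scale=0.1, size=(128, 128))  # synthetic band

# Percentile normalization to [0, 1], as in normalize()
lo, hi = np.percentile(band, (2, 98))
band_n = np.clip((band - lo) / (hi - lo), 0, 1)

# Coefficient of variation = std / mean, as in calculate_cv()
cv = band_n.std() / band_n.mean()
print(f"CV: {cv:.3f}")  # higher CV means more spread relative to the mean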
@@ -205,14 +205,14 @@ def read_tiff_image_for_segmentation(tiff_path):
             red = src.read(1)
             green = src.read(2)
             blue = src.read(3)
-
+
             # Stack to create RGB image
             image = np.dstack((red, green, blue)).astype(np.float32)
-
+
             # Normalize to [0, 1]
             if image.max() > 0:
                 image = image / image.max()
-
+
             return image
         else:
             # If less than 3 bands, handle accordingly
@@ -225,11 +225,11 @@ def read_tiff_image_for_segmentation(tiff_path):
             while len(bands) < 3:
                 bands.append(np.zeros_like(bands[0]))
             image = np.dstack(bands[:3]) # Use first 3 bands
-
+
             # Normalize
             if image.max() > 0:
                 image = image / image.max()
-
+
             return image
     except Exception as e:
         print(f"Error reading TIFF file for segmentation: {e}")
@@ -243,22 +243,22 @@ def extract_cloud_features_from_tiff(tiff_path):
     try:
         with rasterio.open(tiff_path) as src:
             num_bands = min(src.count, 10) # Use up to 10 bands
-
+
             # Process each band
             features = {}
             for i in range(1, num_bands + 1):
                 band = src.read(i)
-
+
                 # Calculate coefficient of variation
                 cv_value = calculate_cv(band)
-
+
                 # Store feature with name matching the training data
                 features[f'band{i}_cv'] = cv_value
-
+
             # If we have fewer than 10 bands, fill the missing ones with zeros
             for i in range(num_bands + 1, 11):
                 features[f'band{i}_cv'] = 0.0
-
+
             return features
     except Exception as e:
         print(f"Error extracting cloud features from TIFF: {e}")
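The classifier downstream always expects exactly ten features, so a 4-band TIFF yields a dictionary like this (values illustrative):

# Illustrative result of extract_cloud_features_from_tiff() on a 4-band file
{'band1_cv': 0.21, 'band2_cv': 0.18, 'band3_cv': 0.25, 'band4_cv': 0.19,
 'band5_cv': 0.0, 'band6_cv': 0.0, 'band7_cv': 0.0, 'band8_cv': 0.0,
 'band9_cv': 0.0, 'band10_cv': 0.0}  # bands 5-10 padded with zeros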
@@ -275,25 +275,26 @@ def extract_cloud_features_from_rgb(image):
         # Make sure image is in float format in range [0,1]
         if image.dtype != np.float32 and image.dtype != np.float64:
             image = image.astype(np.float32)
-
+
         if image.max() > 1.0:
             image = image / 255.0
-
+
         # Create a dictionary for band CV features
         features = {}
-
-        # Process each channel/band
-
+
+        # Process each channel/band (assuming image is H, W, C)
+        num_bands = min(3, image.shape[2])
+        for i in range(num_bands):
             band = image[:, :, i]
             cv_value = calculate_cv(band)
             features[f'band{i+1}_cv'] = cv_value
-
+
         # Fill remaining bands with zeros to match the expected 10 features
-
-
-
+        for i in range(num_bands + 1, 11):
+            features[f'band{i}_cv'] = 0.0
+
         return features
-
+
     except Exception as e:
         print(f"Error extracting cloud features from RGB: {e}")
         import traceback
@@ -304,21 +305,21 @@ def predict_cloud(features_dict, model):
     """Predict if an image is cloudy based on extracted features"""
     if model is None:
         return {'prediction': 'Model unavailable', 'probability': 0.0}
-
+
     try:
         # Ensure all 10 features from band1_cv to band10_cv are present
         feature_dict = {}
         for i in range(1, 11):
             feature_name = f'band{i}_cv'
             feature_dict[feature_name] = features_dict.get(feature_name, 0.0)
-
+
         # Create a DataFrame with all required features
         feature_df = pd.DataFrame([feature_dict])
-
+
         # Enable shape check disabling for prediction
         if hasattr(model, 'set_params'):
             model.set_params(predict_disable_shape_check=True)
-
+
         # Make prediction
         if hasattr(model, 'predict_proba'):
             proba = model.predict_proba(feature_df)
@@ -330,10 +331,10 @@ def predict_cloud(features_dict, model):
             # If model doesn't have predict_proba, use predict and assume binary output
             pred = model.predict(feature_df)
             probability = float(pred[0])
-
+
         # Classification based on probability threshold
         prediction = 'Cloudy' if probability >= 0.5 else 'Non-Cloudy'
-
+
         return {
             'prediction': prediction,
             'probability': probability
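End to end, predict_cloud() is driven like this; a usage sketch assuming cloud_model is the loaded LightGBM classifier (LGBMClassifier provides predict_proba, so the first branch applies; the file name is hypothetical):

features = extract_cloud_features_from_tiff('scene.tif')  # ten band CVs
result = predict_cloud(features, cloud_model)
print(result)  # e.g. {'prediction': 'Cloudy', 'probability': 0.87}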
@@ -369,38 +370,42 @@ def preprocess_image(image, target_size=(128, 128)):
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
         elif image.shape[2] == 4: # RGBA
             image = image[:, :, :3]
-
+
         # Make a copy for display
         display_image = image.copy()
-
+
         # Normalize to [0, 1] if needed
         if display_image.max() > 1.0:
             image = image.astype(np.float32) / 255.0
-
+            display_image = display_image.astype(np.uint8) # Keep display image as uint8
+        else:
+            # If already normalized, scale up for display
+            display_image = (display_image * 255).astype(np.uint8)
+
     # Convert PIL image to numpy
     elif isinstance(image, Image.Image):
         image = np.array(image)
-
+
         # Ensure RGB format
         if len(image.shape) == 2: # Grayscale
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
         elif image.shape[2] == 4: # RGBA
             image = image[:, :, :3]
-
+
         # Make a copy for display
-        display_image = image.copy()
-
-        # Normalize to [0, 1]
+        display_image = image.copy() # display_image is uint8 here
+
+        # Normalize to [0, 1] for model
         image = image.astype(np.float32) / 255.0
     else:
         print(f"Unsupported image type: {type(image)}")
         return None, None
-
+
     # Resize image to the target size
     if albumentations_available:
         # Use albumentations to match training preprocessing
        aug = A.Compose([
-            A.PadIfNeeded(min_height=target_size[0], min_width=target_size[1],
+            A.PadIfNeeded(min_height=target_size[0], min_width=target_size[1],
                           border_mode=cv2.BORDER_CONSTANT, value=0),
            A.CenterCrop(height=target_size[0], width=target_size[1])
        ])
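The PadIfNeeded/CenterCrop pair reproduces the training-time resizing: small inputs are zero-padded up to the target size, larger ones center-cropped down. A standalone sketch, assuming albumentations is installed:

import albumentations as A
import cv2
import numpy as np

aug = A.Compose([
    A.PadIfNeeded(min_height=128, min_width=128,
                  border_mode=cv2.BORDER_CONSTANT, value=0),
    A.CenterCrop(height=128, width=128),
])
out = aug(image=np.zeros((100, 150, 3), dtype=np.float32))['image']
print(out.shape)  # (128, 128, 3)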
@@ -409,160 +414,196 @@ def preprocess_image(image, target_size=(128, 128)):
     else:
         # Fallback to OpenCV
         image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
-
+
     # Convert to tensor [C, H, W]
     image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
-
+
     return image_tensor, display_image

 def extract_file_content(file_obj):
     """Extract content from the file object, handling different types"""
     try:
-
-
-
-        if os.path.exists(
-
-
-            return f.read()
+        # Handle Gradio File object (which has 'name' attribute pointing to temp path)
+        if hasattr(file_obj, 'name') and isinstance(file_obj.name, str):
+            file_path = file_obj.name
+            if os.path.exists(file_path):
+                with open(file_path, 'rb') as f:
+                    return f.read(), file_path # Return path for TIFF reading
             else:
-
-
-
-
-                return file_obj.read()
-        elif isinstance(file_obj, bytes):
-            # Already bytes
-            return file_obj
+                print(f"Temp file path does not exist: {file_path}")
+                return None, None
+
+        # Handle string path (from gr.Examples)
         elif isinstance(file_obj, str):
-            # String path
             if os.path.exists(file_obj):
-
-                return f.read()
+                with open(file_obj, 'rb') as f:
+                    return f.read(), file_obj # Return path for TIFF reading
             else:
-
+                print(f"Provided file path does not exist: {file_obj}")
+                return None, None
+        # Handle BytesIO or other file-like objects (less likely with gr.File/gr.Examples)
+        elif hasattr(file_obj, 'read'):
+            content = file_obj.read()
+            # Need to save to temp file to get a path for rasterio
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
+                temp_path = temp_f.name
+                temp_f.write(content)
+            return content, temp_path # Return path
+        elif isinstance(file_obj, bytes):
+            # Need to save to temp file to get a path for rasterio
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
+                temp_path = temp_f.name
+                temp_f.write(file_obj)
+            return file_obj, temp_path # Return path
         else:
             print(f"Unsupported file object type: {type(file_obj)}")
-            return None
+            return None, None
     except Exception as e:
         print(f"Error extracting file content: {e}")
-
+        import traceback
+        traceback.print_exc()
+        return None, None

 def process_uploaded_tiff(file_obj):
     """Process an uploaded TIFF file for both segmentation and cloud detection"""
+    temp_file_to_delete = None
     try:
-        # Get file content
-
-
-
+        # Get file content and path
+        # We primarily need the path for rasterio
+        _, file_path = extract_file_content(file_obj)
+
+        if file_path is None:
+            print("Failed to get file path for TIFF processing")
             return None, None, None
-
-        #
-
-
-
-
+
+        # Check if extract_file_content created a temp file we need to manage
+        if file_path.endswith('.tmp'):
+            temp_file_to_delete = file_path
+
         # Read as TIFF for segmentation (only using first 3 bands)
-        image_for_segmentation = read_tiff_image_for_segmentation(
-
+        image_for_segmentation = read_tiff_image_for_segmentation(file_path)
+
         # Extract cloud features from all available bands
-        cloud_features = extract_cloud_features_from_tiff(
-
-        # Clean up
-        os.unlink(temp_path)
-
+        cloud_features = extract_cloud_features_from_tiff(file_path)
+
         if image_for_segmentation is None:
+            print("Failed to read TIFF for segmentation.")
             return None, None, None
-
-        # Make a copy for display
-
-
+
+        # Make a copy for display, ensuring uint8 for display
+        if image_for_segmentation.max() <= 1.0 and image_for_segmentation.min() >= 0.0:
+            display_image = (image_for_segmentation * 255).astype(np.uint8)
+        elif image_for_segmentation.dtype == np.uint8:
+            display_image = image_for_segmentation.copy()
+        else:
+            # Attempt normalization for display if out of range
+            display_image = cv2.normalize(image_for_segmentation, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+            print("Warning: TIFF image values outside [0,1], normalized for display.")
+
         # Resize/preprocess for segmentation model
         if albumentations_available:
             aug = A.Compose([
-                A.PadIfNeeded(min_height=128, min_width=128,
+                A.PadIfNeeded(min_height=128, min_width=128,
                               border_mode=cv2.BORDER_CONSTANT, value=0),
                 A.CenterCrop(height=128, width=128)
             ])
-
+            # Need float image for albumentations
+            image_float = image_for_segmentation.astype(np.float32)
+            if image_float.max() > 1.0: # Normalize if not already done
+                image_float = image_float / (image_float.max() + 1e-6)
+
+            augmented = aug(image=image_float)
             image_resized = augmented['image']
         else:
-
-
+            image_float = image_for_segmentation.astype(np.float32)
+            if image_float.max() > 1.0: # Normalize if not already done
+                image_float = image_float / (image_float.max() + 1e-6)
+            image_resized = cv2.resize(image_float, (128, 128), interpolation=cv2.INTER_LINEAR)
+
         # Convert to tensor
         image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
-
+
         return image_tensor, display_image, cloud_features
-
+
     except Exception as e:
         print(f"Error processing uploaded TIFF: {e}")
         import traceback
         traceback.print_exc()
         return None, None, None
+    finally:
+        # Clean up temporary file if created by extract_file_content
+        if temp_file_to_delete and os.path.exists(temp_file_to_delete):
+            try:
+                os.unlink(temp_file_to_delete)
+            except Exception as e_del:
+                print(f"Error deleting temp file {temp_file_to_delete}: {e_del}")

 def process_uploaded_mask(file_obj):
     """Process an uploaded mask file"""
+    temp_file_to_delete = None
     try:
-        # Get file content
-
-        if
-
-
-
-        #
-
-
-
-        if isinstance(file_name, str) and '.' in file_name:
-            suffix = '.' + file_name.split('.')[-1].lower()
-
-        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
-            temp_path = temp_file.name
-            temp_file.write(file_content)
-
+        # Get file content and path
+        _, file_path = extract_file_content(file_obj)
+        if file_path is None:
+            print("Failed to get file path for Mask processing")
+            return None
+
+        # Check if extract_file_content created a temp file we need to manage
+        if file_path.endswith('.tmp'):
+            temp_file_to_delete = file_path
+
         # Check if it's a TIFF file
-        if
-            mask = read_tiff_mask(
+        if file_path.lower().endswith(('.tif', '.tiff')):
+            mask = read_tiff_mask(file_path)
         else:
             # Try to open as a regular image
             try:
-                mask_img = Image.open(
+                mask_img = Image.open(file_path).convert('L') # Convert to grayscale
                 mask = np.array(mask_img)
-                if len(mask.shape) == 3:
-                    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
             except Exception as e:
                 print(f"Error opening mask as regular image: {e}")
-                os.unlink(temp_path)
                 return None
-
-        # Clean up
-        os.unlink(temp_path)
-
+
         if mask is None:
-
-
+            print("Failed to read mask data.")
+            return None
+
         # Resize mask to 128x128
         if albumentations_available:
             aug = A.Compose([
-                A.PadIfNeeded(min_height=128, min_width=128,
+                A.PadIfNeeded(min_height=128, min_width=128,
                               border_mode=cv2.BORDER_CONSTANT, value=0),
                 A.CenterCrop(height=128, width=128)
             ])
-            augmented = aug(image=mask)
+            augmented = aug(image=mask) # Use 'image' key even for mask
             mask_resized = augmented['image']
         else:
             mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)
-
+
         # Binarize the mask (0: background, 1: wetland)
-
-
+        # Use a threshold (e.g., 127 for typical grayscale) or >0 if it's already somewhat binary
+        threshold = np.median(mask_resized[mask_resized > 0]) if np.any(mask_resized > 0) else 1
+        mask_binary = (mask_resized >= threshold).astype(np.uint8)
+        print(f"Mask binarized. Original min/max: {mask.min()}/{mask.max()}, Resized min/max: {mask_resized.min()}/{mask_resized.max()}, Binary sum: {mask_binary.sum()}")
+
         return mask_binary
-
+
     except Exception as e:
         print(f"Error processing uploaded mask: {e}")
         import traceback
         traceback.print_exc()
         return None
+    finally:
+        # Clean up temporary file if created by extract_file_content
+        if temp_file_to_delete and os.path.exists(temp_file_to_delete):
+            try:
+                os.unlink(temp_file_to_delete)
+            except Exception as e_del:
+                print(f"Error deleting temp mask file {temp_file_to_delete}: {e_del}")

 def predict_segmentation(image_tensor):
     """
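Note the changed contract: extract_file_content() now returns a (content, path) tuple rather than bare bytes, so callers can hand a real filesystem path to rasterio. A hypothetical caller:

# Sketch of the new calling convention (uploaded_file is any supported input type).
content, path = extract_file_content(uploaded_file)
if path is not None:
    with rasterio.open(path) as src:  # rasterio needs an on-disk path
        print(src.count, 'bands')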
@@ -570,10 +611,10 @@ def predict_segmentation(image_tensor):
     """
     try:
         image_tensor = image_tensor.to(device)
-
+
         with torch.no_grad():
             output = model(image_tensor)
-
+
         # Handle different model output formats
         if isinstance(output, dict):
             output = output['out']
@@ -581,7 +622,7 @@ def predict_segmentation(image_tensor):
             pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
         else: # Binary output (from smp models)
             pred = (torch.sigmoid(output) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
-
+
         return pred
     except Exception as e:
         print(f"Error during prediction: {e}")
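For intuition, the binary branch thresholds sigmoid probabilities at 0.5; a minimal sketch with dummy logits:

import torch

logits = torch.tensor([[-2.0, 0.3], [1.5, -0.1]])     # dummy model output
mask = (torch.sigmoid(logits) > 0.5).to(torch.uint8)  # 1 where P(wetland) > 0.5
print(mask)  # tensor([[0, 1], [1, 0]], dtype=torch.uint8)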
@@ -591,208 +632,288 @@ def calculate_metrics(pred_mask, gt_mask):
     """
     Calculate evaluation metrics between prediction and ground truth
     """
+    if pred_mask is None or gt_mask is None:
+        print("Cannot calculate metrics: Invalid masks provided.")
+        return {}
+    if pred_mask.shape != gt_mask.shape:
+        print(f"Cannot calculate metrics: Shape mismatch - Pred {pred_mask.shape}, GT {gt_mask.shape}")
+        # Optionally resize one to match the other, e.g., resize GT to Pred shape
+        gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
+        print(f"Resized GT mask to {gt_mask.shape}")
+
     # Ensure binary masks
     pred_binary = (pred_mask > 0).astype(np.uint8)
     gt_binary = (gt_mask > 0).astype(np.uint8)
-
+
     # Calculate intersection and union
     intersection = np.logical_and(pred_binary, gt_binary).sum()
     union = np.logical_or(pred_binary, gt_binary).sum()
-
+
     # Calculate IoU
     iou = intersection / union if union > 0 else 0
-
+
     # Calculate precision and recall
     true_positive = intersection
-
-
-
-
-
-
+    pred_positive_count = pred_binary.sum()
+    gt_positive_count = gt_binary.sum()
+
+    false_positive = pred_positive_count - true_positive
+    false_negative = gt_positive_count - true_positive
+
+    precision = true_positive / pred_positive_count if pred_positive_count > 0 else 0
+    recall = true_positive / gt_positive_count if gt_positive_count > 0 else 0
+
     # Calculate F1 score
     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
-
+
     metrics = {
         "Wetlands IoU": float(iou),
         "Precision": float(precision),
         "Recall": float(recall),
         "F1 Score": float(f1)
     }
-
+
     return metrics

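A tiny worked example of these metrics on 2×2 masks:

import numpy as np

pred = np.array([[1, 1], [0, 0]], dtype=np.uint8)
gt = np.array([[1, 0], [1, 0]], dtype=np.uint8)

tp = np.logical_and(pred, gt).sum()                 # 1
iou = tp / np.logical_or(pred, gt).sum()            # 1/3
precision = tp / pred.sum()                         # 1/2
recall = tp / gt.sum()                              # 1/2
f1 = 2 * precision * recall / (precision + recall)  # 0.5
print(iou, precision, recall, f1)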
 def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
     """
     Process input images, generate predictions for both wetland segmentation and cloud detection
     """
+    image_tensor = None
+    display_image = None
+    cloud_features = None
+    pred_mask = None
+    gt_mask_processed = None
+    result_image = None
+    result_text = "Processing..." # Initial message
+
     try:
-        #
-        if
-
-
-        # Process the input and initialize cloud_features
-        cloud_features = None
-
-        if input_tiff is not None and input_tiff:
-            # Process uploaded TIFF file for both segmentation and cloud detection
+        # Determine input type and process
+        if input_tiff is not None:
+            print("Processing TIFF input...")
             image_tensor, display_image, cloud_features = process_uploaded_tiff(input_tiff)
             if image_tensor is None:
                 return None, "Failed to process the input TIFF file."
+            print("TIFF processing complete.")
         elif input_image is not None:
-
             image_tensor, display_image = preprocess_image(input_image)
             if image_tensor is None:
                 return None, "Failed to process the input image."
-
-
-            cloud_features = extract_cloud_features_from_rgb(display_image)
+            # For RGB images, extract cloud features separately
+            print("Extracting cloud features from RGB...")
+            cloud_features = extract_cloud_features_from_rgb(display_image) # Use display_image (uint8)
+            print("Image processing complete.")
         else:
-            return None, "
-
+            return None, "Please upload an image or TIFF file."
+
+        # --- Perform Predictions ---
         # Get wetland segmentation prediction
+        print("Performing wetland segmentation...")
         pred_mask = predict_segmentation(image_tensor)
         if pred_mask is None:
             return None, "Failed to generate wetland segmentation prediction."
-
+        print(f"Segmentation prediction generated, shape: {pred_mask.shape}, type: {pred_mask.dtype}")
+
         # Get cloud prediction
         cloud_result = {'prediction': 'Unknown', 'probability': 0.0}
         if cloud_features and cloud_model:
+            print("Performing cloud detection...")
             cloud_result = predict_cloud(cloud_features, cloud_model)
-
-
-
+            print(f"Cloud detection result: {cloud_result}")
+        elif cloud_model is None:
+            print("Cloud detection model not available.")
+        else:
+            print("Cloud features not extracted, skipping cloud detection.")
+
+        # --- Process Ground Truth and Metrics (if provided) ---
         metrics_text = ""
-
-
+        if gt_mask_file is not None:
+            print("Processing ground truth mask...")
             gt_mask_processed = process_uploaded_mask(gt_mask_file)
-
+
             if gt_mask_processed is not None:
+                print(f"Ground truth mask processed, shape: {gt_mask_processed.shape}, type: {gt_mask_processed.dtype}")
+                print("Calculating metrics...")
                 metrics = calculate_metrics(pred_mask, gt_mask_processed)
                 metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                plt.axis('off')
-
-                plt.subplot(1, 3, 3)
-                plt.imshow(pred_mask, cmap='binary')
-                plt.title("Prediction")
-                plt.axis('off')
+                print(f"Metrics calculated: {metrics_text}")
+            else:
+                print("Failed to process ground truth mask.")
+                metrics_text = "Ground truth mask provided but could not be processed."
+
+        # --- Create Visualization ---
+        print("Creating result visualization...")
+        fig, axes = plt.subplots(1, 3 if gt_mask_processed is not None else 2, figsize=(12, 5))
+        fig.suptitle("Analysis Results", fontsize=16)
+
+        # Ensure display_image is in correct format for imshow
+        if display_image.dtype != np.uint8:
+            display_image_for_plot = (display_image * 255).astype(np.uint8) if display_image.max() <= 1 else display_image.astype(np.uint8)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
+            display_image_for_plot = display_image
+
+        # Resize display image to match mask size for consistent display if needed, or keep original?
+        # Let's keep original input size for clarity, masks are 128x128
+        display_image_resized = cv2.resize(display_image_for_plot, (512, 512), interpolation=cv2.INTER_LINEAR)
+
+        ax_idx = 0
+        axes[ax_idx].imshow(display_image_resized)
+        axes[ax_idx].set_title("Input Image (Resized for Display)")
+        axes[ax_idx].axis('off')
+        ax_idx += 1
+
+        if gt_mask_processed is not None:
+            axes[ax_idx].imshow(gt_mask_processed, cmap='viridis') # Use viridis for GT
+            axes[ax_idx].set_title("Ground Truth (128x128)")
+            axes[ax_idx].axis('off')
+            ax_idx += 1
+
+        # Ensure pred_mask is suitable for imshow (e.g., scale if needed, but should be 0/1)
+        axes[ax_idx].imshow(pred_mask, cmap='plasma') # Use plasma for Prediction
+        axes[ax_idx].set_title("Predicted Wetlands (128x128)")
+        axes[ax_idx].axis('off')
+
+        # Calculate wetland percentage from prediction
         wetland_percentage = np.mean(pred_mask) * 100
-
-        #
-        result_text = f"
-
+
+        # --- Format Output Text ---
+        result_text = f"--- Analysis Summary ---\n"
+        result_text += f"Wetland Coverage (Predicted): {wetland_percentage:.2f}%\n\n"
+
         # Add cloud detection results
         result_text += f"Cloud Detection: {cloud_result['prediction']} "
         result_text += f"({cloud_result['probability']*100:.2f}% Cloud probability)\n\n"
-
+
         # Add segmentation metrics if available
         if metrics_text:
-
-
+            if "could not be processed" in metrics_text:
+                result_text += metrics_text # Add the error message
+            else:
+                result_text += f"--- Evaluation Metrics (vs Ground Truth) ---\n{metrics_text}"
+        elif gt_mask_file is not None:
+            result_text += "--- Evaluation Metrics ---\nGround truth provided but metrics could not be calculated."
+
         # Convert figure to image for display
-        plt.tight_layout()
+        plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
         buf = BytesIO()
-        plt.savefig(buf, format='png')
+        plt.savefig(buf, format='png', bbox_inches='tight')
         buf.seek(0)
         result_image = Image.open(buf)
         plt.close(fig)
-
+        print("Visualization complete.")
+
         return result_image, result_text
-
     except Exception as e:
-        print(f"Error in
+        print(f"Error in process_images: {e}")
         import traceback
         traceback.print_exc()
-
+        # Return current state if partial results exist, otherwise error
+        error_message = f"An error occurred during processing: {str(e)}"
+        return result_image if result_image else None, result_text + f"\n\nERROR: {error_message}"
+
+# --- Define Example Files ---
+# Ensure these paths are correct relative to where app.py is run
+# These should be at the root of your Gradio Space repository
+example_list = [
+    ["test_p1_cloudy_input.tif", "test_p1_cloudy_output.tif"],
+    ["test_p1_noncloudy_input.tif", "test_p1_noncloudy_output.tif"],
+    ["test_p1_cloudy_input.tif", None], # Example without ground truth
+    ["test_p1_noncloudy_input.tif", None], # Example without ground truth
+]

-# Create Gradio
+# --- Create Gradio Interface ---
 with gr.Blocks(title="Wetlands Segmentation & Cloud Detection") as demo:
     gr.Markdown("# Wetlands Segmentation & Cloud Detection from Satellite Imagery")
     gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas and detect cloud cover. Optionally, you can also upload a ground truth mask for evaluation.")
-
+
     with gr.Row():
-        with gr.Column():
-            # Input options
+        with gr.Column(scale=1): # Input column
             gr.Markdown("### Input")
-            with gr.
-
-
-
-
-
+            with gr.Tabs():
+                with gr.TabItem("Upload Image"):
+                    input_image = gr.Image(label="Upload Satellite Image (JPG, PNG etc.)", type="numpy")
+                with gr.TabItem("Upload TIFF"):
+                    input_tiff = gr.File(label="Upload Multi-Band TIFF File", file_types=[".tif", ".tiff"], type="filepath") # Use filepath for easier handling
+
             # Ground truth mask as file upload
-            gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
-
+            gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"], type="filepath")
+
+            # --- Add Examples Section ---
+            gr.Markdown("### Load Examples")
+            gr.Examples(
+                examples=example_list,
+                inputs=[input_tiff, gt_mask_file], # Corresponds to TIFF input and GT Mask input
+                label="Click an example below to load files:",
+                # Outputs are not needed here, examples just populate inputs
+                # examples_per_page=4 # Optional: control pagination
+            )
+            # --- End Examples Section ---
+
             process_btn = gr.Button("Analyze Image", variant="primary")
-
-        with gr.Column():
-            # Output
+
+        with gr.Column(scale=2): # Output column (make it wider)
             gr.Markdown("### Results")
-            output_image = gr.Image(label="Segmentation Results", type="pil")
-            output_text = gr.Textbox(label="Statistics", lines=
-
+            output_image = gr.Image(label="Segmentation Results", type="pil", height=450) # Adjust height as needed
+            output_text = gr.Textbox(label="Statistics & Metrics", lines=10, scale=1) # Adjust lines and scale
+
     # Information about the models
     gr.Markdown("### About these models")
     gr.Markdown("""
     This application uses two models:
-
+
     **1. Wetland Segmentation Model:**
     - Architecture: DeepLabv3+ with ResNet-34
-    - Input: RGB satellite imagery
+    - Input: RGB satellite imagery (extracted from first 3 bands of TIFF if provided)
     - Output: Binary segmentation mask (Wetland vs Background)
-    - Resolution: 128×128 pixels
-
+    - Resolution: Processed at 128×128 pixels
+
     **2. Cloud Detection Model:**
     - Architecture: LightGBM Classifier
-    - Input: CV features extracted from up to 10 image bands
+    - Input: Coefficient of Variation (CV) features extracted from up to 10 image bands (from TIFF)
     - Output: Binary classification (Cloudy vs Non-Cloudy) with probability
-
+
     **Tips for best results:**
-    -
-    -
-    - The
-    -
-    - For
-
-
-    - For ground truth masks, both TIFF and standard image formats are supported
-
-    **Repository:** [dcrey7/wetlands_segmentation_deeplabsv3plus](https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus)
+    - Use the 'Upload TIFF' tab for multi-band satellite data to enable accurate cloud detection.
+    - The cloud detection model expects up to 10 bands. Performance may vary with fewer bands.
+    - The example files demonstrate cloudy/non-cloudy scenarios with corresponding ground truth.
+    - The models work best with images similar in characteristics to those used in training.
+    - For ground truth masks, both TIFF and standard image formats (PNG, JPG) are supported. Ensure the mask clearly delineates the target class.
+
+    **Repository:** [dcrey7/wetland_segmentation_deeplabsv3plus](https://huggingface.co/spaces/dcrey7/wetland_segmentation_deeplabsv3plus)
     """)
-
+
     # Set up event handlers
     process_btn.click(
         fn=process_images,
         inputs=[input_image, input_tiff, gt_mask_file],
-        outputs=[output_image, output_text]
+        outputs=[output_image, output_text],
+        api_name="analyze" # Optional: name for API endpoint
     )

+    # Clear inputs when switching tabs (optional but good UX)
+    def clear_other_input(selected_tab):
+        if selected_tab == 0: # "Upload Image" tab selected
+            return gr.update(value=None), gr.update() # Clear TIFF input
+        elif selected_tab == 1: # "Upload TIFF" tab selected
+            return gr.update(), gr.update(value=None) # Clear Image input
+        return gr.update(), gr.update() # Default case
+
+    # Assuming the Tabs component itself can be accessed or named,
+    # otherwise, this part might need adjustment based on Gradio's specific API for Tabs.
+    # If Tabs don't directly support change events easily, this part can be omitted.
+    # For now, let's assume tab switching doesn't automatically clear, user needs to manage inputs.
+
 # Launch the app
-
+if __name__ == "__main__":
+    demo.launch(debug=True) # Enable debug for more detailed logs during development
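With api_name="analyze" set on the click handler, the Space can also be called programmatically; a hypothetical client sketch, assuming gradio_client is installed (on recent gradio_client versions file arguments may need to be wrapped with gradio_client.handle_file):

from gradio_client import Client

client = Client("dcrey7/wetland_segmentation_deeplabsv3plus")
result_image, stats = client.predict(
    None,                        # input_image (unused)
    "test_p1_cloudy_input.tif",  # input_tiff
    None,                        # gt_mask_file (optional)
    api_name="/analyze",
)
print(stats)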
|