dcrey7 committed
Commit 61fd66a · verified · 1 parent: c5b14be

Create app.py

Files changed (1): app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
import os
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from io import BytesIO
import urllib.request
import tempfile
import rasterio
import warnings
warnings.filterwarnings("ignore")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the DeepLabv3+ model architecture
# This needs to match your trained model architecture
from torchvision.models.segmentation import deeplabv3_resnet50

# Initialize the model (weights=None replaces the deprecated pretrained=False)
model = deeplabv3_resnet50(weights=None, num_classes=2)

# Download model weights from HuggingFace
MODEL_REPO = "dcrey7/wetlands_segmentation_deeplabsv3plus"
MODEL_FILENAME = "DeepLabV3plus_best_model.pth"

def download_model_weights():
    """Download model weights from the HuggingFace repository"""
    try:
        os.makedirs('weights', exist_ok=True)
        local_path = os.path.join('weights', MODEL_FILENAME)

        # Check if weights are already downloaded
        if os.path.exists(local_path):
            print(f"Model weights already downloaded at {local_path}")
            return local_path

        # Download weights
        print(f"Downloading model weights from {MODEL_REPO}...")
        url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILENAME}"
        urllib.request.urlretrieve(url, local_path)
        print(f"Model weights downloaded to {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        return None
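
# A note on the download step: huggingface_hub (a dependency of gradio, so
# available here) can do the same fetch with built-in caching. The variant
# below is a sketch of that swap, not part of the original app, and is never
# called by the code that follows.
def download_model_weights_hub():
    """Hypothetical alternative to download_model_weights() via huggingface_hub."""
    try:
        from huggingface_hub import hf_hub_download
        # Returns a path inside the local HF cache instead of ./weights
        return hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    except Exception as e:
        print(f"Error downloading via huggingface_hub: {e}")
        return None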

# Load the model weights
weights_path = download_model_weights()
if weights_path:
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print("Model weights loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
else:
    print("No weights available. Model will not produce valid predictions.")

model.to(device)
model.eval()
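
# Note: on PyTorch >= 1.13, torch.load(weights_path, map_location=device,
# weights_only=True) is a safer way to load a bare state_dict; the call above
# is left as-is in case the checkpoint stores more than tensors.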

def preprocess_image(image, target_size=(128, 128)):
    """
    Preprocess an image for inference
    """
    # Convert to numpy array if PIL image
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Ensure RGB format
    if len(image.shape) == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:  # RGBA
        image = image[:, :, :3]

    # Resize image
    image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    # Normalize image to [0, 1]
    image_normalized = image_resized.astype(np.float32)
    if image_normalized.max() > 0:
        image_normalized = image_normalized / image_normalized.max()

    # Convert to tensor [1, C, H, W]
    image_tensor = torch.from_numpy(image_normalized.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, image_resized
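
# Minimal sanity check of the preprocessing contract (illustrative only, on a
# synthetic image): any RGB array should come out as a [1, 3, 128, 128] float
# tensor, the input shape the model expects.
_dummy = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
_t, _ = preprocess_image(_dummy)
assert tuple(_t.shape) == (1, 3, 128, 128), _t.shape
del _dummy, _t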

def preprocess_tiff(tiff_path, target_size=(128, 128)):
    """
    Preprocess a TIFF file for inference
    """
    try:
        with rasterio.open(tiff_path) as src:
            # Read RGB bands if available
            if src.count >= 3:
                red = src.read(1)
                green = src.read(2)
                blue = src.read(3)
                image = np.dstack((red, green, blue))
            else:
                # If fewer than 3 bands, read all available bands
                bands = [src.read(i + 1) for i in range(src.count)]
                # If only one band, duplicate it to create RGB
                if len(bands) == 1:
                    image = np.dstack((bands[0], bands[0], bands[0]))
                else:
                    # Use available bands and pad with zeros if needed
                    while len(bands) < 3:
                        bands.append(np.zeros_like(bands[0]))
                    image = np.dstack(bands[:3])  # Use first 3 bands

            # Normalize image to [0, 1]
            image = image.astype(np.float32)
            if image.max() > 0:
                image = image / image.max()

            # Create a display image
            display_image = (image * 255).astype(np.uint8)

            # Resize image
            image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
            display_resized = cv2.resize(display_image, target_size, interpolation=cv2.INTER_LINEAR)

            # Convert to tensor [1, C, H, W]
            image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)

            return image_tensor, display_resized
    except Exception as e:
        print(f"Error processing TIFF: {e}")
        return None, None
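
# The TIFF path above rescales the whole band stack by its global max, which
# preserves relative band intensities. A per-band rescale is another common
# choice for 16-bit imagery; a hypothetical variant, shown for contrast and
# unused below:
def normalize_per_band(image):
    """Rescale each band of an HxWxC float array to [0, 1] independently."""
    image = image.astype(np.float32)
    for b in range(image.shape[2]):
        band = image[:, :, b]
        lo, hi = band.min(), band.max()
        if hi > lo:
            image[:, :, b] = (band - lo) / (hi - lo)
    return image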

def preprocess_mask(mask, target_size=(128, 128)):
    """
    Preprocess a ground truth mask
    """
    # Convert to numpy array if PIL image
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    # Convert to grayscale if needed
    if len(mask.shape) == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)

    # Resize mask
    mask_resized = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

    # Binarize the mask (0: background, 1: wetland)
    mask_binary = (mask_resized > 127).astype(np.uint8)

    return mask_binary
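
# Caveat: the 127 threshold assumes masks are stored with wetland pixels as
# bright values (e.g. 255 on a black background). A mask already encoded in
# {0, 1} would binarize to all zeros here, so such masks should be scaled to
# 0/255 before upload.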

def predict_segmentation(image_tensor):
    """
    Run inference on the model
    """
    try:
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            output = model(image_tensor)

        # Extract the output based on model type
        if isinstance(output, dict):
            output = output['out']

        # Get the predicted class per pixel (0: background, 1: wetland)
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

        return pred
    except Exception as e:
        print(f"Error during prediction: {e}")
        return None
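
# If a soft probability map is ever needed (e.g. to threshold at a value other
# than the argmax decision), a hypothetical variant could look like this; it
# is not used by the app below:
def predict_probabilities(image_tensor):
    """Return the wetland-class probability map as an HxW numpy array."""
    with torch.no_grad():
        output = model(image_tensor.to(device))
        if isinstance(output, dict):
            output = output['out']
        # Softmax over the class dimension; channel 1 is the wetland class
        return torch.softmax(output, dim=1)[0, 1].cpu().numpy()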

def calculate_metrics(pred_mask, gt_mask):
    """
    Calculate evaluation metrics between prediction and ground truth
    """
    # Ensure binary masks
    pred_binary = (pred_mask > 0).astype(np.uint8)
    gt_binary = (gt_mask > 0).astype(np.uint8)

    # Calculate intersection and union
    intersection = np.logical_and(pred_binary, gt_binary).sum()
    union = np.logical_or(pred_binary, gt_binary).sum()

    # Calculate IoU
    iou = intersection / union if union > 0 else 0

    # Calculate precision and recall
    true_positive = intersection
    false_positive = pred_binary.sum() - true_positive
    false_negative = gt_binary.sum() - true_positive

    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0

    # Calculate F1 score
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    metrics = {
        "IoU": float(iou),
        "Precision": float(precision),
        "Recall": float(recall),
        "F1 Score": float(f1)
    }

    return metrics
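
# Toy worked example of the metric math (illustrative only): the prediction
# marks two pixels, the ground truth one of them, so intersection = 1 and
# union = 2, giving IoU = 0.5, precision = 0.5, recall = 1.0, and
# F1 = 2 * 0.5 * 1.0 / 1.5 = 2/3.
_m = calculate_metrics(np.array([[1, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))
assert abs(_m["IoU"] - 0.5) < 1e-6 and abs(_m["F1 Score"] - 2 / 3) < 1e-6
del _m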

def process_images(input_image=None, input_tiff=None, gt_mask=None):
    """
    Process input images and generate predictions
    """
    try:
        # Check if we have input
        if input_image is None and input_tiff is None:
            return None, "Please upload an image or TIFF file."

        # Process the input image
        if input_tiff is not None:
            # Save the uploaded TIFF bytes to a temporary file
            with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as temp_tiff:
                temp_tiff_path = temp_tiff.name
                temp_tiff.write(input_tiff)

            # Process TIFF file
            image_tensor, display_image = preprocess_tiff(temp_tiff_path)
            os.unlink(temp_tiff_path)  # Clean up temp file
        else:
            # Process regular image
            image_tensor, display_image = preprocess_image(input_image)

        if image_tensor is None:
            return None, "Failed to process the input image."

        # Get prediction
        pred_mask = predict_segmentation(image_tensor)
        if pred_mask is None:
            return None, "Failed to generate prediction."

        # Process ground truth mask if provided
        gt_mask_processed = None
        metrics_text = ""

        if gt_mask is not None:
            gt_mask_processed = preprocess_mask(gt_mask)
            metrics = calculate_metrics(pred_mask, gt_mask_processed)
            metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])

        # Create visualization
        fig = plt.figure(figsize=(12, 6))

        if gt_mask_processed is not None:
            # Show original, ground truth, and prediction
            plt.subplot(1, 3, 1)
            plt.imshow(display_image)
            plt.title("Input Image")
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(gt_mask_processed, cmap='binary')
            plt.title("Ground Truth")
            plt.axis('off')

            plt.subplot(1, 3, 3)
            plt.imshow(pred_mask, cmap='binary')
            plt.title("Prediction")
            plt.axis('off')
        else:
            # Show original and prediction
            plt.subplot(1, 2, 1)
            plt.imshow(display_image)
            plt.title("Input Image")
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(pred_mask, cmap='binary')
            plt.title("Predicted Wetlands")
            plt.axis('off')

        # Calculate wetland percentage
        wetland_percentage = np.mean(pred_mask) * 100

        # Add metrics info
        result_text = f"Wetland Coverage: {wetland_percentage:.2f}%"
        if metrics_text:
            result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"

        # Convert figure to image
        buf = BytesIO()
        plt.tight_layout()
        plt.savefig(buf, format='png', dpi=150)
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close(fig)

        return result_image, result_text

    except Exception as e:
        print(f"Error in processing: {e}")
        return None, f"Error: {str(e)}"
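
# Deployment note (an assumption about the hosting environment): on headless
# servers matplotlib usually falls back to the non-interactive Agg backend on
# its own; if figure rendering ever fails, calling matplotlib.use("Agg")
# before importing pyplot is the usual fix.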

# Create Gradio interface
with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
    gr.Markdown("# Wetlands Segmentation from Satellite Imagery")
    gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas. Optionally, you can also upload a ground truth mask for evaluation.")

    with gr.Row():
        with gr.Column():
            # Input options
            gr.Markdown("### Input")
            with gr.Tab("Upload Image"):
                input_image = gr.Image(label="Upload Satellite Image", type="numpy")

            with gr.Tab("Upload TIFF"):
                # type="binary" so process_images receives raw bytes it can
                # write to a temporary file
                input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"], type="binary")

            gt_mask = gr.Image(label="Ground Truth Mask (Optional)", type="numpy")

            process_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column():
            # Output
            gr.Markdown("### Results")
            output_image = gr.Image(label="Segmentation Results", type="pil")
            output_text = gr.Textbox(label="Statistics", lines=6)

    # Information about the model
    gr.Markdown("### About this model")
    gr.Markdown("""
    This application uses a DeepLabv3+ model trained to segment wetland areas in satellite imagery.

    **Model Details:**
    - Architecture: DeepLabv3+ with a ResNet-50 backbone
    - Input: RGB satellite imagery
    - Output: Binary segmentation mask (wetland vs. background)
    - Resolution: 128×128 pixels

    **Tips for best results:**
    - The model works best with RGB satellite imagery
    - For optimal results, use images with characteristics similar to the training data
    - The model focuses on identifying wetland regions in natural landscapes

    **Repository:** [dcrey7/wetlands_segmentation_deeplabsv3plus](https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus)
    """)

    # Set up event handlers
    process_btn.click(
        fn=process_images,
        inputs=[input_image, input_tiff, gt_mask],
        outputs=[output_image, output_text]
    )

# Launch the app
demo.launch()