Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,6 @@ from io import BytesIO
|
|
| 12 |
import urllib.request
|
| 13 |
import tempfile
|
| 14 |
import rasterio
|
| 15 |
-
from rasterio.plot import reshape_as_image
|
| 16 |
import warnings
|
| 17 |
warnings.filterwarnings("ignore")
|
| 18 |
|
|
@@ -69,7 +68,7 @@ if smp_available:
|
|
| 69 |
)
|
| 70 |
else:
|
| 71 |
# Fallback to a simple model that won't actually work but allows the UI to load
|
| 72 |
-
print("Warning: Using a placeholder model that won't produce
|
| 73 |
from torch import nn
|
| 74 |
class PlaceholderModel(nn.Module):
|
| 75 |
def __init__(self):
|
|
@@ -133,10 +132,8 @@ def read_tiff_image(tiff_path):
|
|
| 133 |
This matches your training data loading approach
|
| 134 |
"""
|
| 135 |
try:
|
| 136 |
-
print(f"Reading TIFF image from: {tiff_path}")
|
| 137 |
# Read the image using rasterio (get RGB channels)
|
| 138 |
with rasterio.open(tiff_path) as src:
|
| 139 |
-
print(f"TIFF opened successfully. Number of bands: {src.count}")
|
| 140 |
# Check if we have enough bands
|
| 141 |
if src.count >= 3:
|
| 142 |
red = src.read(1)
|
|
@@ -145,7 +142,6 @@ def read_tiff_image(tiff_path):
|
|
| 145 |
|
| 146 |
# Stack to create RGB image
|
| 147 |
image = np.dstack((red, green, blue)).astype(np.float32)
|
| 148 |
-
print(f"RGB image created, shape: {image.shape}, min: {image.min()}, max: {image.max()}")
|
| 149 |
|
| 150 |
# Normalize to [0, 1]
|
| 151 |
if image.max() > 0:
|
|
@@ -154,7 +150,6 @@ def read_tiff_image(tiff_path):
|
|
| 154 |
return image
|
| 155 |
else:
|
| 156 |
# If less than 3 bands, handle accordingly
|
| 157 |
-
print(f"Warning: TIFF file has only {src.count} bands, RGB expected")
|
| 158 |
bands = [src.read(i+1) for i in range(src.count)]
|
| 159 |
# If only one band, duplicate to create RGB
|
| 160 |
if len(bands) == 1:
|
|
@@ -172,8 +167,6 @@ def read_tiff_image(tiff_path):
|
|
| 172 |
return image
|
| 173 |
except Exception as e:
|
| 174 |
print(f"Error reading TIFF file: {e}")
|
| 175 |
-
import traceback
|
| 176 |
-
traceback.print_exc()
|
| 177 |
return None
|
| 178 |
|
| 179 |
def read_tiff_mask(mask_path):
|
|
@@ -182,17 +175,12 @@ def read_tiff_mask(mask_path):
|
|
| 182 |
This matches your training data loading approach
|
| 183 |
"""
|
| 184 |
try:
|
| 185 |
-
print(f"Reading TIFF mask from: {mask_path}")
|
| 186 |
# Read mask
|
| 187 |
with rasterio.open(mask_path) as src:
|
| 188 |
-
print(f"Mask TIFF opened successfully. Number of bands: {src.count}")
|
| 189 |
mask = src.read(1).astype(np.uint8)
|
| 190 |
-
print(f"Mask shape: {mask.shape}, min: {mask.min()}, max: {mask.max()}, unique values: {np.unique(mask)}")
|
| 191 |
return mask
|
| 192 |
except Exception as e:
|
| 193 |
print(f"Error reading mask file: {e}")
|
| 194 |
-
import traceback
|
| 195 |
-
traceback.print_exc()
|
| 196 |
return None
|
| 197 |
|
| 198 |
def preprocess_image(image, target_size=(128, 128)):
|
|
@@ -216,7 +204,6 @@ def preprocess_image(image, target_size=(128, 128)):
|
|
| 216 |
|
| 217 |
# Convert PIL image to numpy
|
| 218 |
elif isinstance(image, Image.Image):
|
| 219 |
-
print("Converting PIL image to numpy array")
|
| 220 |
image = np.array(image)
|
| 221 |
|
| 222 |
# Ensure RGB format
|
|
@@ -234,8 +221,6 @@ def preprocess_image(image, target_size=(128, 128)):
|
|
| 234 |
print(f"Unsupported image type: {type(image)}")
|
| 235 |
return None, None
|
| 236 |
|
| 237 |
-
print(f"Image shape: {image.shape}, min: {image.min()}, max: {image.max()}")
|
| 238 |
-
|
| 239 |
# Resize image to the target size
|
| 240 |
if albumentations_available:
|
| 241 |
# Use albumentations to match training preprocessing
|
|
@@ -255,59 +240,54 @@ def preprocess_image(image, target_size=(128, 128)):
|
|
| 255 |
|
| 256 |
return image_tensor, display_image
|
| 257 |
|
| 258 |
-
def
|
| 259 |
-
"""
|
| 260 |
-
Save uploaded file to a temporary file
|
| 261 |
-
"""
|
| 262 |
try:
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
if
|
| 267 |
-
#
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
#
|
| 272 |
-
content
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# If it's a path
|
| 283 |
with open(file_obj, 'rb') as f:
|
| 284 |
-
|
| 285 |
-
print(f"Copied file from {file_obj} to {temp_path}")
|
| 286 |
else:
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
return temp_path
|
| 292 |
except Exception as e:
|
| 293 |
-
print(f"Error
|
| 294 |
-
import traceback
|
| 295 |
-
traceback.print_exc()
|
| 296 |
return None
|
| 297 |
|
| 298 |
def process_uploaded_tiff(file_obj):
|
| 299 |
-
"""
|
| 300 |
-
Process an uploaded TIFF file
|
| 301 |
-
"""
|
| 302 |
try:
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
# Save to a temporary file
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
|
| 310 |
-
# Read
|
| 311 |
image = read_tiff_image(temp_path)
|
| 312 |
|
| 313 |
# Clean up
|
|
@@ -343,16 +323,24 @@ def process_uploaded_tiff(file_obj):
|
|
| 343 |
return None, None
|
| 344 |
|
| 345 |
def process_uploaded_mask(file_obj):
|
| 346 |
-
"""
|
| 347 |
-
Process an uploaded mask file
|
| 348 |
-
"""
|
| 349 |
try:
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
# Save to a temporary file
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
# Check if it's a TIFF file
|
| 358 |
if temp_path.lower().endswith(('.tif', '.tiff')):
|
|
@@ -364,7 +352,6 @@ def process_uploaded_mask(file_obj):
|
|
| 364 |
mask = np.array(mask_img)
|
| 365 |
if len(mask.shape) == 3:
|
| 366 |
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
| 367 |
-
print(f"Opened mask as regular image, shape: {mask.shape}")
|
| 368 |
except Exception as e:
|
| 369 |
print(f"Error opening mask as regular image: {e}")
|
| 370 |
os.unlink(temp_path)
|
|
@@ -420,8 +407,6 @@ def predict_segmentation(image_tensor):
|
|
| 420 |
return pred
|
| 421 |
except Exception as e:
|
| 422 |
print(f"Error during prediction: {e}")
|
| 423 |
-
import traceback
|
| 424 |
-
traceback.print_exc()
|
| 425 |
return None
|
| 426 |
|
| 427 |
def calculate_metrics(pred_mask, gt_mask):
|
|
@@ -464,11 +449,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
|
|
| 464 |
Process input images and generate predictions
|
| 465 |
"""
|
| 466 |
try:
|
| 467 |
-
print("\n---- Starting new processing request ----")
|
| 468 |
-
print(f"Input image type: {type(input_image) if input_image is not None else None}")
|
| 469 |
-
print(f"Input TIFF type: {type(input_tiff) if input_tiff is not None else None}")
|
| 470 |
-
print(f"Ground truth mask type: {type(gt_mask_file) if gt_mask_file is not None else None}")
|
| 471 |
-
|
| 472 |
# Check if we have input
|
| 473 |
if input_image is None and input_tiff is None:
|
| 474 |
return None, "Please upload an image or TIFF file."
|
|
@@ -488,7 +468,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
|
|
| 488 |
return None, "No valid input provided."
|
| 489 |
|
| 490 |
# Get prediction
|
| 491 |
-
print("Running model prediction...")
|
| 492 |
pred_mask = predict_segmentation(image_tensor)
|
| 493 |
if pred_mask is None:
|
| 494 |
return None, "Failed to generate prediction."
|
|
@@ -498,19 +477,13 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
|
|
| 498 |
metrics_text = ""
|
| 499 |
|
| 500 |
if gt_mask_file is not None and gt_mask_file:
|
| 501 |
-
print("Processing ground truth mask...")
|
| 502 |
gt_mask_processed = process_uploaded_mask(gt_mask_file)
|
| 503 |
|
| 504 |
if gt_mask_processed is not None:
|
| 505 |
-
print("Calculating metrics...")
|
| 506 |
metrics = calculate_metrics(pred_mask, gt_mask_processed)
|
| 507 |
metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
|
| 508 |
-
print(f"Metrics calculated: {metrics}")
|
| 509 |
-
else:
|
| 510 |
-
print("Failed to process ground truth mask.")
|
| 511 |
|
| 512 |
# Create visualization
|
| 513 |
-
print("Creating visualization...")
|
| 514 |
fig = plt.figure(figsize=(12, 6))
|
| 515 |
|
| 516 |
if gt_mask_processed is not None:
|
|
@@ -549,15 +522,14 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
|
|
| 549 |
if metrics_text:
|
| 550 |
result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"
|
| 551 |
|
| 552 |
-
# Convert figure to image
|
| 553 |
-
buf = BytesIO()
|
| 554 |
plt.tight_layout()
|
| 555 |
-
|
|
|
|
| 556 |
buf.seek(0)
|
| 557 |
result_image = Image.open(buf)
|
| 558 |
plt.close(fig)
|
| 559 |
|
| 560 |
-
print("Processing completed successfully.")
|
| 561 |
return result_image, result_text
|
| 562 |
|
| 563 |
except Exception as e:
|
|
@@ -581,7 +553,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
|
|
| 581 |
with gr.Tab("Upload TIFF"):
|
| 582 |
input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
|
| 583 |
|
| 584 |
-
#
|
| 585 |
gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
|
| 586 |
|
| 587 |
process_btn = gr.Button("Analyze Image", variant="primary")
|
|
|
|
| 12 |
import urllib.request
|
| 13 |
import tempfile
|
| 14 |
import rasterio
|
|
|
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings("ignore")
|
| 17 |
|
|
|
|
| 68 |
)
|
| 69 |
else:
|
| 70 |
# Fallback to a simple model that won't actually work but allows the UI to load
|
| 71 |
+
print("Warning: Using a placeholder model that won't produce valid predictions.")
|
| 72 |
from torch import nn
|
| 73 |
class PlaceholderModel(nn.Module):
|
| 74 |
def __init__(self):
|
|
|
|
| 132 |
This matches your training data loading approach
|
| 133 |
"""
|
| 134 |
try:
|
|
|
|
| 135 |
# Read the image using rasterio (get RGB channels)
|
| 136 |
with rasterio.open(tiff_path) as src:
|
|
|
|
| 137 |
# Check if we have enough bands
|
| 138 |
if src.count >= 3:
|
| 139 |
red = src.read(1)
|
|
|
|
| 142 |
|
| 143 |
# Stack to create RGB image
|
| 144 |
image = np.dstack((red, green, blue)).astype(np.float32)
|
|
|
|
| 145 |
|
| 146 |
# Normalize to [0, 1]
|
| 147 |
if image.max() > 0:
|
|
|
|
| 150 |
return image
|
| 151 |
else:
|
| 152 |
# If less than 3 bands, handle accordingly
|
|
|
|
| 153 |
bands = [src.read(i+1) for i in range(src.count)]
|
| 154 |
# If only one band, duplicate to create RGB
|
| 155 |
if len(bands) == 1:
|
|
|
|
| 167 |
return image
|
| 168 |
except Exception as e:
|
| 169 |
print(f"Error reading TIFF file: {e}")
|
|
|
|
|
|
|
| 170 |
return None
|
| 171 |
|
| 172 |
def read_tiff_mask(mask_path):
|
|
|
|
| 175 |
This matches your training data loading approach
|
| 176 |
"""
|
| 177 |
try:
|
|
|
|
| 178 |
# Read mask
|
| 179 |
with rasterio.open(mask_path) as src:
|
|
|
|
| 180 |
mask = src.read(1).astype(np.uint8)
|
|
|
|
| 181 |
return mask
|
| 182 |
except Exception as e:
|
| 183 |
print(f"Error reading mask file: {e}")
|
|
|
|
|
|
|
| 184 |
return None
|
| 185 |
|
| 186 |
def preprocess_image(image, target_size=(128, 128)):
|
|
|
|
| 204 |
|
| 205 |
# Convert PIL image to numpy
|
| 206 |
elif isinstance(image, Image.Image):
|
|
|
|
| 207 |
image = np.array(image)
|
| 208 |
|
| 209 |
# Ensure RGB format
|
|
|
|
| 221 |
print(f"Unsupported image type: {type(image)}")
|
| 222 |
return None, None
|
| 223 |
|
|
|
|
|
|
|
| 224 |
# Resize image to the target size
|
| 225 |
if albumentations_available:
|
| 226 |
# Use albumentations to match training preprocessing
|
|
|
|
| 240 |
|
| 241 |
return image_tensor, display_image
|
| 242 |
|
| 243 |
+
def extract_file_content(file_obj):
|
| 244 |
+
"""Extract content from the file object, handling different types"""
|
|
|
|
|
|
|
| 245 |
try:
|
| 246 |
+
if hasattr(file_obj, 'name') and isinstance(file_obj, str):
|
| 247 |
+
# Handle Gradio's NamedString
|
| 248 |
+
content = file_obj
|
| 249 |
+
if os.path.exists(content):
|
| 250 |
+
# It's a path
|
| 251 |
+
with open(content, 'rb') as f:
|
| 252 |
+
return f.read()
|
| 253 |
+
else:
|
| 254 |
+
# It's content
|
| 255 |
+
return content.encode('latin1')
|
| 256 |
+
elif hasattr(file_obj, 'read'):
|
| 257 |
+
# File-like object
|
| 258 |
+
return file_obj.read()
|
| 259 |
+
elif isinstance(file_obj, bytes):
|
| 260 |
+
# Already bytes
|
| 261 |
+
return file_obj
|
| 262 |
+
elif isinstance(file_obj, str):
|
| 263 |
+
# String path
|
| 264 |
+
if os.path.exists(file_obj):
|
|
|
|
| 265 |
with open(file_obj, 'rb') as f:
|
| 266 |
+
return f.read()
|
|
|
|
| 267 |
else:
|
| 268 |
+
return file_obj.encode('utf-8')
|
| 269 |
+
else:
|
| 270 |
+
print(f"Unsupported file object type: {type(file_obj)}")
|
| 271 |
+
return None
|
|
|
|
| 272 |
except Exception as e:
|
| 273 |
+
print(f"Error extracting file content: {e}")
|
|
|
|
|
|
|
| 274 |
return None
|
| 275 |
|
| 276 |
def process_uploaded_tiff(file_obj):
|
| 277 |
+
"""Process an uploaded TIFF file"""
|
|
|
|
|
|
|
| 278 |
try:
|
| 279 |
+
# Get file content
|
| 280 |
+
file_content = extract_file_content(file_obj)
|
| 281 |
+
if file_content is None:
|
| 282 |
+
print("Failed to extract file content")
|
| 283 |
+
return None, None
|
| 284 |
|
| 285 |
# Save to a temporary file
|
| 286 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
|
| 287 |
+
temp_path = temp_file.name
|
| 288 |
+
temp_file.write(file_content)
|
| 289 |
|
| 290 |
+
# Read as TIFF
|
| 291 |
image = read_tiff_image(temp_path)
|
| 292 |
|
| 293 |
# Clean up
|
|
|
|
| 323 |
return None, None
|
| 324 |
|
| 325 |
def process_uploaded_mask(file_obj):
|
| 326 |
+
"""Process an uploaded mask file"""
|
|
|
|
|
|
|
| 327 |
try:
|
| 328 |
+
# Get file content
|
| 329 |
+
file_content = extract_file_content(file_obj)
|
| 330 |
+
if file_content is None:
|
| 331 |
+
return None
|
| 332 |
|
| 333 |
# Save to a temporary file
|
| 334 |
+
# Determine suffix based on file name if available
|
| 335 |
+
suffix = '.tif'
|
| 336 |
+
if hasattr(file_obj, 'name'):
|
| 337 |
+
file_name = getattr(file_obj, 'name')
|
| 338 |
+
if isinstance(file_name, str) and '.' in file_name:
|
| 339 |
+
suffix = '.' + file_name.split('.')[-1].lower()
|
| 340 |
+
|
| 341 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
| 342 |
+
temp_path = temp_file.name
|
| 343 |
+
temp_file.write(file_content)
|
| 344 |
|
| 345 |
# Check if it's a TIFF file
|
| 346 |
if temp_path.lower().endswith(('.tif', '.tiff')):
|
|
|
|
| 352 |
mask = np.array(mask_img)
|
| 353 |
if len(mask.shape) == 3:
|
| 354 |
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
| 355 |
except Exception as e:
|
| 356 |
print(f"Error opening mask as regular image: {e}")
|
| 357 |
os.unlink(temp_path)
|
|
|
|
| 407 |
return pred
|
| 408 |
except Exception as e:
|
| 409 |
print(f"Error during prediction: {e}")
|
|
|
|
|
|
|
| 410 |
return None
|
| 411 |
|
| 412 |
def calculate_metrics(pred_mask, gt_mask):
|
|
|
|
| 449 |
Process input images and generate predictions
|
| 450 |
"""
|
| 451 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
# Check if we have input
|
| 453 |
if input_image is None and input_tiff is None:
|
| 454 |
return None, "Please upload an image or TIFF file."
|
|
|
|
| 468 |
return None, "No valid input provided."
|
| 469 |
|
| 470 |
# Get prediction
|
|
|
|
| 471 |
pred_mask = predict_segmentation(image_tensor)
|
| 472 |
if pred_mask is None:
|
| 473 |
return None, "Failed to generate prediction."
|
|
|
|
| 477 |
metrics_text = ""
|
| 478 |
|
| 479 |
if gt_mask_file is not None and gt_mask_file:
|
|
|
|
| 480 |
gt_mask_processed = process_uploaded_mask(gt_mask_file)
|
| 481 |
|
| 482 |
if gt_mask_processed is not None:
|
|
|
|
| 483 |
metrics = calculate_metrics(pred_mask, gt_mask_processed)
|
| 484 |
metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
# Create visualization
|
|
|
|
| 487 |
fig = plt.figure(figsize=(12, 6))
|
| 488 |
|
| 489 |
if gt_mask_processed is not None:
|
|
|
|
| 522 |
if metrics_text:
|
| 523 |
result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"
|
| 524 |
|
| 525 |
+
# Convert figure to image for display
|
|
|
|
| 526 |
plt.tight_layout()
|
| 527 |
+
buf = BytesIO()
|
| 528 |
+
plt.savefig(buf, format='png')
|
| 529 |
buf.seek(0)
|
| 530 |
result_image = Image.open(buf)
|
| 531 |
plt.close(fig)
|
| 532 |
|
|
|
|
| 533 |
return result_image, result_text
|
| 534 |
|
| 535 |
except Exception as e:
|
|
|
|
| 553 |
with gr.Tab("Upload TIFF"):
|
| 554 |
input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
|
| 555 |
|
| 556 |
+
# Ground truth mask as file upload
|
| 557 |
gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
|
| 558 |
|
| 559 |
process_btn = gr.Button("Analyze Image", variant="primary")
|