File size: 6,525 Bytes
e972242 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import streamlit as st
import cv2
import numpy as np
from src.models import perform_custom_segmentation
from src.utils import resize_image, download_image
import os
import torch
# Constants
# Default target size in pixels (presumably (width, height) — TODO confirm).
# NOTE(review): not referenced anywhere in the visible code; the sidebar
# number inputs hard-code the same 750 defaults instead of reading this.
TARGET_SIZE = (750, 750)
def get_parameters_from_sidebar() -> dict:
    """Collect the segmentation settings chosen in the sidebar.

    Returns:
        A dict with one integer entry per slider (``train_epoch``,
        ``mod_dim1``, ``mod_dim2``, ``min_label_num``, ``max_label_num``)
        plus ``target_size`` as a ``(width, height)`` tuple.
    """
    sidebar = st.sidebar
    sidebar.header("Segmentation Parameters")

    # Parameter key -> (min, max, default), passed positionally to the slider.
    slider_specs = {
        'train_epoch': (1, 200, 43),
        'mod_dim1': (1, 128, 67),
        'mod_dim2': (1, 128, 63),
        'min_label_num': (1, 20, 3),
        'max_label_num': (1, 200, 25),
    }

    settings = {}
    for key, (lowest, highest, initial) in slider_specs.items():
        # The widget caption is the key in title case, e.g. "Train Epoch".
        caption = key.replace('_', ' ').title()
        settings[key] = sidebar.slider(caption, lowest, highest, initial)

    # Numeric inputs for the size the image is resized to before segmentation.
    width = sidebar.number_input("Target Size Width", 100, 1200, 750)
    height = sidebar.number_input("Target Size Height", 100, 1200, 750)
    settings['target_size'] = (width, height)
    return settings
def display_segmentation_results() -> None:
    """Render the segmented image currently held in session state."""
    current_image = st.session_state.segmented_image
    st.image(
        current_image,
        caption='Updated Segmented Image',
        use_column_width=True,
    )
def randomize_colors() -> None:
    """Give every label of the segmented image a fresh random colour.

    Reads ``st.session_state.segmented_image``, replaces it with a recoloured
    copy, records the ``old colour -> new colour`` mapping in
    ``st.session_state.new_colors`` and bumps
    ``st.session_state.image_update_trigger``.

    Does nothing when no segmented image exists yet (e.g. the button was
    pressed before the first segmentation run).
    """
    image = st.session_state.segmented_image
    if image is None or image.size == 0:
        return

    # One row per distinct colour, plus for every pixel the index of its row.
    unique_labels, inverse = np.unique(
        image.reshape(-1, 3), axis=0, return_inverse=True
    )
    # The shape of ``inverse`` differs between NumPy releases when ``axis`` is
    # given; flatten it so indexing below is version independent.
    inverse = np.ravel(inverse)
    label_count = len(unique_labels)

    # Redraw on the (unlikely) event of duplicate colours so that two labels
    # can never collapse into one.
    palette = np.random.randint(0, 256, size=(label_count, 3))
    while len(np.unique(palette, axis=0)) < label_count:
        palette = np.random.randint(0, 256, size=(label_count, 3))

    # Recolour all pixels in a single pass from the *original* label indices.
    # Replacing colours one after another in place would also repaint pixels
    # whose freshly assigned colour happens to equal a later "old" colour.
    recoloured = palette.astype(image.dtype)[inverse].reshape(image.shape)
    st.session_state.segmented_image = recoloured

    # Update color mappings in session state (plain ints keep the keys
    # hashable and comparable with tuples built elsewhere).
    random_colors = {
        tuple(int(c) for c in old): tuple(int(c) for c in new)
        for old, new in zip(unique_labels, palette)
    }
    st.session_state.new_colors.update(random_colors)
    st.session_state.image_update_trigger += 1  # Trigger image update
def handle_color_picking() -> None:
    """Show one colour picker per label and apply the chosen colours.

    For every distinct colour in ``st.session_state.segmented_image`` a colour
    picker (key ``label_<i>``, in ``np.unique`` order) is rendered with the
    current colour as its default. The picked colours are recorded in
    ``st.session_state.new_colors`` and applied to the image, after which
    ``st.session_state.image_update_trigger`` is incremented.

    Does nothing when no segmented image exists.
    """
    image = st.session_state.segmented_image
    if image is None or image.size == 0:
        return

    # One row per distinct colour, plus for every pixel the index of its row.
    unique_labels, inverse = np.unique(
        image.reshape(-1, 3), axis=0, return_inverse=True
    )
    # Flatten: the shape of ``inverse`` varies across NumPy releases when
    # ``axis`` is given.
    inverse = np.ravel(inverse)

    # palette[i] will hold the colour chosen for label i.
    palette = np.empty_like(unique_labels)
    for i, label in enumerate(unique_labels):
        red, green, blue = (int(channel) for channel in label)
        hex_label = f'#{red:02x}{green:02x}{blue:02x}'
        new_color = st.color_picker(
            f"Choose a new color for label {i}", value=hex_label, key=f"label_{i}"
        )
        digits = new_color.lstrip('#')
        new_color_rgb = tuple(int(digits[j:j + 2], 16) for j in (0, 2, 4))
        st.session_state.new_colors[(red, green, blue)] = new_color_rgb
        palette[i] = new_color_rgb

    # Apply only the mapping of the labels present right now, and do it in a
    # single pass from the original label indices. Replaying every historical
    # entry of ``new_colors`` one by one would recolour a label again whenever
    # its new colour matches an older entry's source colour.
    st.session_state.segmented_image = palette[inverse].reshape(image.shape)

    # After updating colors, trigger an update to the segmented image display
    st.session_state.image_update_trigger += 1
def calculate_and_display_label_percentages() -> None:
    """Show the share of image pixels covered by each label.

    A label is one distinct colour of ``st.session_state.segmented_image``.
    Labels are numbered in ``np.unique`` order, the same order used for the
    colour pickers, so "Label i" here and "label i" there refer to the same
    region. Each line is rendered as a small colour swatch followed by the
    percentage with two decimals.

    Does nothing when no segmented image exists.
    """
    image = st.session_state.segmented_image
    if image is None or image.size == 0:
        return

    # Count distinct colours directly. Going through a grayscale conversion
    # would merge different colours that share the same gray value.
    unique_colors, counts = np.unique(
        image.reshape(-1, 3), axis=0, return_counts=True
    )
    total_pixels = counts.sum()

    st.write("Label Percentages:")
    for label, (color, count) in enumerate(zip(unique_colors, counts)):
        red, green, blue = (int(channel) for channel in color)
        hex_color = f'#{red:02x}{green:02x}{blue:02x}'
        percentage = (count / total_pixels) * 100
        # hex_color is built from integers only, so the inline HTML is safe.
        color_box = f'<div style="display: inline-block; width: 20px; height: 20px; background-color: {hex_color}; margin-right: 10px;"></div>'
        st.markdown(f'{color_box} Label {label}: {percentage:.2f}%', unsafe_allow_html=True)
def main() -> None:
    """Run the PetroSeg Streamlit page.

    Draws the title and usage hints, initialises the session-state slots
    (``segmented_image``, ``new_colors``, ``image_update_trigger``), reads the
    sidebar parameters and, once an image is uploaded, shows it, resizes it to
    the requested target size and offers segmentation and recolouring buttons.
    When a segmented image exists, the colour pickers, the image, the label
    percentages and the download control are rendered.
    """
    st.title("PetroSeg")
    st.info("""
- **Training Epochs**: Higher values will lead to fewer segments but may take more time.
- **Image Size**: For better efficiency, upload small-sized images.
- **Cache**: For best results, clear the cache between different image uploads. You can do this from the menu in the top-right corner.
""")
    if torch.cuda.is_available():
        # This app has no command-line arguments, so there is no GPU id to
        # read. Respect a value set by the environment, else use GPU 0.
        os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    # Initialize session state if not already initialized
    if 'segmented_image' not in st.session_state:
        st.session_state.segmented_image = None
    if 'new_colors' not in st.session_state:
        st.session_state.new_colors = {}
    if 'image_update_trigger' not in st.session_state:
        st.session_state.image_update_trigger = 0

    # Define params before using it
    params = get_parameters_from_sidebar()
    uploaded_image = st.sidebar.file_uploader("Upload an image", type=["jpg", "png", "jpeg", "bmp", "tiff", "webp"])
    if uploaded_image:
        file_bytes = np.asarray(bytearray(uploaded_image.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, 1)  # 1 == cv2.IMREAD_COLOR
        if image is None:
            st.error("Error loading image. Please check the file and try again.")
            return

        # OpenCV decodes to BGR; Streamlit expects RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        st.image(image_rgb, caption='Original Image', use_column_width=True)

        # Use the target size specified by the user
        target_size = params['target_size']
        image_resized = resize_image(image_rgb, target_size)

        if st.sidebar.button("Start Segmentation"):
            st.session_state.segmented_image = perform_custom_segmentation(image_resized, params)

        if st.sidebar.button("Change Colors"):
            if st.session_state.segmented_image is None:
                # Nothing to recolour before the first segmentation run.
                st.sidebar.warning("Run the segmentation before changing colors.")
            else:
                randomize_colors()

        if st.session_state.segmented_image is not None:
            handle_color_picking()
            display_segmentation_results()
            calculate_and_display_label_percentages()
            download_image(st.session_state.segmented_image, 'segmented_image.png')


if __name__ == "__main__":
    main()
|