File size: 6,525 Bytes
e972242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

import streamlit as st
import cv2
import numpy as np
from src.models import perform_custom_segmentation
from src.utils import resize_image, download_image
import os
import torch

# Constants
# Default resize target as a (width, height) pair; the sidebar's "Target Size"
# number inputs start at the same 750 x 750 value.
TARGET_SIZE = (750, 750)

def get_parameters_from_sidebar() -> dict:
    """Render the sidebar controls and collect the segmentation parameters.

    Returns:
        A dict with one slider value per parameter (``train_epoch``,
        ``mod_dim1``, ``mod_dim2``, ``min_label_num``, ``max_label_num``) plus
        ``target_size``, a ``(width, height)`` tuple read from two number inputs.
    """
    st.sidebar.header("Segmentation Parameters")
    # parameter name -> (min_value, max_value, default) passed to st.sidebar.slider
    slider_specs = {
        'train_epoch': (1, 200, 43),
        'mod_dim1': (1, 128, 67),
        'mod_dim2': (1, 128, 63),
        'min_label_num': (1, 20, 3),
        'max_label_num': (1, 200, 25),
    }
    params = {
        name: st.sidebar.slider(name.replace('_', ' ').title(), *bounds)
        for name, bounds in slider_specs.items()
    }

    # Number inputs (min 100, max 1200) for the resize target; the defaults come
    # from the module-level TARGET_SIZE instead of repeating the literal here.
    target_size_width = st.sidebar.number_input("Target Size Width", 100, 1200, TARGET_SIZE[0])
    target_size_height = st.sidebar.number_input("Target Size Height", 100, 1200, TARGET_SIZE[1])
    params['target_size'] = (target_size_width, target_size_height)

    return params


def display_segmentation_results() -> None:
    """Show the segmented image currently stored in session state."""
    segmented = st.session_state.segmented_image
    st.image(
        segmented,
        caption='Updated Segmented Image',
        use_column_width=True,
    )

def randomize_colors() -> None:
    """Replace every distinct color in the segmented image with a random RGB color.

    Mutates ``st.session_state.segmented_image`` in place, merges the
    ``old color -> new color`` mapping into ``st.session_state.new_colors`` and
    increments ``st.session_state.image_update_trigger``. Does nothing when no
    segmentation result exists yet.
    """
    image = st.session_state.segmented_image
    if image is None:
        return

    # Assign every pixel its label index once, then remap all labels in a single
    # step. Recoloring one label at a time (mask, assign, next label) could repaint
    # pixels a second time when a freshly drawn color equals a label that has not
    # been processed yet.
    unique_labels, pixel_label = np.unique(image.reshape(-1, 3), axis=0, return_inverse=True)
    new_palette = np.random.randint(0, 256, size=(len(unique_labels), 3))
    # reshape(-1) normalizes the inverse-index shape across NumPy versions.
    image[...] = new_palette.astype(image.dtype)[pixel_label.reshape(-1)].reshape(image.shape)

    # Update color mappings in session state
    random_colors = {
        tuple(int(c) for c in old): tuple(int(c) for c in new)
        for old, new in zip(unique_labels, new_palette)
    }
    st.session_state.new_colors.update(random_colors)
    st.session_state.image_update_trigger += 1  # Trigger image update

def handle_color_picking() -> None:
    """Render one color picker per segment color and apply the chosen colors.

    Each distinct RGB color in ``st.session_state.segmented_image`` gets a
    ``st.color_picker`` preset to that color. The picks are recorded in
    ``st.session_state.new_colors`` and applied to the image in place; then
    ``st.session_state.image_update_trigger`` is incremented.
    """
    image = st.session_state.segmented_image
    if image is None:
        return

    unique_labels, pixel_label = np.unique(image.reshape(-1, 3), axis=0, return_inverse=True)
    palette = unique_labels.copy()
    for i, label in enumerate(unique_labels):
        old_rgb = tuple(int(c) for c in label)
        hex_label = '#{:02x}{:02x}{:02x}'.format(*old_rgb)
        # Key the widget by the color it edits, not by its position: np.unique
        # sorts the labels, so after any recolor (including randomize_colors) the
        # positions shift and a positional key can carry one picker's retained
        # value over to a different label.
        new_color = st.color_picker(
            f"Choose a new color for label {i}", value=hex_label, key=f"label_{hex_label}"
        )
        new_color_rgb = tuple(int(new_color.lstrip('#')[j:j + 2], 16) for j in (0, 2, 4))
        st.session_state.new_colors[old_rgb] = new_color_rgb
        palette[i] = new_color_rgb

    # Apply only this run's picks, all at once. Re-applying every historical
    # mapping one by one could chain repaints (A->B followed by B->C moves A to C).
    if not np.array_equal(palette, unique_labels):
        image[...] = palette[pixel_label.reshape(-1)].reshape(image.shape)

    # After updating colors, trigger an update to the segmented image display
    st.session_state.image_update_trigger += 1

def calculate_and_display_label_percentages() -> None:
    """Show what percentage of the segmented image each label color covers.

    A label is a distinct RGB color of ``st.session_state.segmented_image``.
    Labels are numbered in the sorted order produced by ``np.unique``, and each
    line shows an HTML color swatch followed by the percentage.
    """
    image = st.session_state.segmented_image
    # Count distinct RGB colors directly rather than through a grayscale
    # conversion: two different segment colors can share a gray intensity, which
    # would merge their pixel counts and show only one of their swatches.
    colors, counts = np.unique(image.reshape(-1, 3), axis=0, return_counts=True)
    total_pixels = np.sum(counts)

    st.write("Label Percentages:")
    for label, (color, count) in enumerate(zip(colors, counts)):
        hex_color = '#' + ''.join(f'{int(channel):02x}' for channel in color)
        percentage = (count / total_pixels) * 100
        color_box = f'<div style="display: inline-block; width: 20px; height: 20px; background-color: {hex_color}; margin-right: 10px;"></div>'
        st.markdown(f'{color_box} Label {label}: {percentage:.2f}%', unsafe_allow_html=True)

def main(gpu_id: "int | None" = None) -> None:
    """Build and run the PetroSeg Streamlit page.

    Args:
        gpu_id: Optional CUDA device index exported as ``CUDA_VISIBLE_DEVICES``.
            The default ``None`` leaves the environment untouched, and the
            ``__main__`` guard calls ``main()`` without arguments.
    """
    st.title("PetroSeg")
    st.info("""
    - **Training Epochs**: Higher values will lead to fewer segments but may take more time.
    - **Image Size**: For better efficiency, upload small-sized images.
    - **Cache**: For best results, clear the cache between different image uploads. You can do this from the menu in the top-right corner.
    """)

    # No argparse namespace exists in this module, so the device index is an
    # explicit optional argument rather than an `args.gpu_id` lookup.
    if gpu_id is not None:
        # NOTE(review): only takes effect if CUDA has not yet been initialised in
        # this process; confirm nothing touches CUDA before main() runs.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    # Initialize session state if not already initialized
    if 'segmented_image' not in st.session_state:
        st.session_state.segmented_image = None
    if 'new_colors' not in st.session_state:
        st.session_state.new_colors = {}
    if 'image_update_trigger' not in st.session_state:
        st.session_state.image_update_trigger = 0

    # Define params before using it
    params = get_parameters_from_sidebar()

    uploaded_image = st.sidebar.file_uploader("Upload an image", type=["jpg", "png", "jpeg", "bmp", "tiff", "webp"])
    if not uploaded_image:
        return

    # getvalue() returns the whole upload regardless of the buffer's read
    # position, so decoding does not depend on whether it was read before.
    raw_bytes = uploaded_image.getvalue()
    image = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR) if raw_bytes else None

    if image is None:
        st.error("Error loading image. Please check the file and try again.")
        return

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption='Original Image', use_column_width=True)

    # Use the target size specified by the user
    target_size = params['target_size']
    image_resized = resize_image(image_rgb, target_size)

    if st.sidebar.button("Start Segmentation"):
        st.session_state.segmented_image = perform_custom_segmentation(image_resized, params)
        # Recorded color mappings refer to the previous result's colors.
        st.session_state.new_colors = {}

    # The button is always rendered, but recoloring needs an existing result.
    if st.sidebar.button("Change Colors") and st.session_state.segmented_image is not None:
        randomize_colors()

    if st.session_state.segmented_image is not None:
        handle_color_picking()
        display_segmentation_results()
        calculate_and_display_label_percentages()
        download_image(st.session_state.segmented_image, 'segmented_image.png')

# Script entry point: Streamlit executes the app script with __name__ set to
# "__main__", so this runs main() on every rerun.
if __name__ == "__main__":
    main()