Spaces:
Paused
Paused
| import streamlit as st | |
| import io | |
| import collections | |
| from scipy.io import loadmat | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import argparse | |
| import torch.nn as nn | |
| import torch.utils.data as Data | |
| import torch.backends.cudnn as cudnn | |
| from scipy.io import loadmat | |
| from scipy.io import savemat | |
| from torch import optim | |
| from torch.autograd import Variable | |
| from sstvit import SSTViT | |
| from sklearn.metrics import confusion_matrix | |
| import matplotlib.pyplot as plt | |
| from matplotlib import colors | |
| import numpy as np | |
| from patchify import patchify, unpatchify | |
| import time | |
| from matplotlib import colors as mcolors | |
| import base64 | |
| import pandas as pd | |
| import st_aggrid | |
| import os | |
| import json | |
| import plotly.express as px | |
# Inline stylesheet injected into the page: caps the width of the direct child
# <div> of Streamlit's `section.main` element at 60rem.
css = """
<style>
section.main > div {max-width:60rem}
</style>
"""
# unsafe_allow_html is required, otherwise the <style> tag is rendered as text.
st.markdown(css, unsafe_allow_html=True)
class Args(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``args.seed`` is equivalent to ``args['seed']`` and ``args.seed = 1`` is
    equivalent to ``args['seed'] = 1``.

    Raises:
        AttributeError: when reading an attribute that is not a key of the
            dictionary. Raising ``AttributeError`` (rather than the
            ``KeyError`` a raw ``dict.__getitem__`` would leak) keeps
            ``hasattr`` and ``getattr(obj, name, default)`` working.
    """

    # Attribute assignment stores the value as a dictionary item.
    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real dict
        # methods such as ``copy`` or ``items`` are never shadowed by keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
# Run configuration, exposed with attribute access through Args.
# In this listing only `gpu_id`, `seed` and `band_patches` are read; the
# remaining entries look like training hyper-parameters kept for reference.
args = Args(
    dataset='mg',
    flag_test='train',
    gpu_id=0,
    seed=0,
    batch_size=64,
    test_freq=10,
    patches=5,
    band_patches=1,
    epoches=2000,
    learning_rate=5e-4,
    gamma=0.9,
    weight_decay=0.0,
    train_number=500,
)
# dict.copy() hands back a plain dict snapshot of the configuration.
obj = args.copy()
# Restrict the CUDA devices visible to this process to the configured GPU id.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
| def test_epoch(model, test_loader): | |
| pre = np.array([]) | |
| for batch_idx, (batch_data_t1, batch_data_t2) in enumerate(test_loader): | |
| batch_data_t1 = batch_data_t1 | |
| batch_data_t2 = batch_data_t2 | |
| batch_pred = model(batch_data_t1,batch_data_t2) | |
| _, pred = batch_pred.topk(1, 1, True, True) | |
| pp = pred.squeeze() | |
| pre = np.append(pre, pp.data.cpu().numpy()) | |
| return pre | |
# Alternating 'Before'/'After' captions (not referenced in this listing).
mdic = ['Before', 'After'] * 2
# Three-colour palette (blue, red, green) used when drawing prediction maps.
# NOTE: this rebinds `colors`, hiding the `matplotlib.colors` module imported
# under that name; the module itself stays reachable as `mcolors`.
colors = ['#3b68f8', '#ff0201', '#23fe01']
cmap = mcolors.ListedColormap(colors)

# Parameter Setting: seed every random number generator from the same value
# and ask cuDNN for deterministic kernels without auto-tuning.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False
def encode_masks_to_rgb(masks):
    """Turn a 2-D class-index mask into an RGB image.

    Mask value 0 is painted blue, 1 red and 2 green; any other value is left
    black.

    Args:
        masks: 2-D array of class indices.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``.
    """
    palette = ((0, 0, 255), (255, 0, 0), (0, 255, 0))
    rows, cols = masks.shape
    # Start from an all-black canvas and paint one class at a time.
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        canvas[masks == class_id] = rgb
    return canvas
def _count_pixels_by_class(pred, mangrove_label):
    """Count the pixels of each known class colour in an RGB image.

    Shared implementation behind ``count_pixel`` and ``count_pixel1``, which
    differ only in the label given to green pixels.

    Args:
        pred: array of shape ``(height, width, 3)`` holding RGB pixel values.
        mangrove_label: class name reported for green ``(0, 255, 0)`` pixels.

    Returns:
        ``pandas.DataFrame`` with one row per known class (blue, red, green,
        in that order) and the columns ``class``, ``percentage`` (share of
        all pixels in the image, 0-100) and ``pixel_count``.

    Raises:
        ValueError: if ``pred`` is not a 3-channel image array.
    """
    color2label = {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): mangrove_label,
    }
    rgb = np.asarray(pred)
    if rgb.ndim != 3 or rgb.shape[-1] != 3:
        raise ValueError("pred must be an RGB array of shape (height, width, 3)")
    pixels = rgb.reshape(-1, 3)
    total_pixels = pixels.shape[0]

    labels, percentages, pixel_counts = [], [], []
    for color, label in color2label.items():
        # Pixels whose colour is not in color2label are simply not counted;
        # they still contribute to total_pixels, i.e. to the denominator.
        count = int(np.count_nonzero((pixels == color).all(axis=1)))
        labels.append(label)
        pixel_counts.append(count)
        percentages.append((count / total_pixels) * 100 if total_pixels else 0.0)

    return pd.DataFrame({
        "class": labels,
        "percentage": percentages,
        "pixel_count": pixel_counts,
    })


def count_pixel(pred):
    """Summarise class pixels of an RGB map, labelling green "Mangrove Before".

    See ``_count_pixels_by_class`` for the argument and the returned columns.
    """
    return _count_pixels_by_class(pred, "Mangrove Before")


def count_pixel1(pred):
    """Summarise class pixels of an RGB map, labelling green "Mangrove After".

    See ``_count_pixels_by_class`` for the argument and the returned columns.
    """
    return _count_pixels_by_class(pred, "Mangrove After")
# ---------------------------------------------------------------------------
# Streamlit page: upload a .mat scene, preview it, run the change-detection
# model and report the results.
# NOTE(review): this listing arrived with its indentation flattened; the
# nesting below is reconstructed from the statements' data dependencies.
# Confirm against the deployed app, in particular which widgets belong to
# which column/container.
# ---------------------------------------------------------------------------
file = st.file_uploader("Upload file", type=['mat'])
if file:
    # Parse the upload once and reuse the result; every loadmat() call
    # re-reads and re-decodes the whole file.
    mat = loadmat(file)
    data_img2 = mat['data_img2']
    data_img1 = mat['data_img1']

    st.subheader("Preview Dataset")
    col1, col2 = st.columns(2)
    with col1:
        fig = plt.figure(figsize=(5, 5))
        plt.subplot(121)
        plt.imshow(data_img1)
        plt.title('Before', fontweight='bold')
        plt.xticks([])
        plt.yticks([])
        plt.subplot(122)
        plt.imshow(data_img2)
        plt.title('After', fontweight='bold')
        plt.xticks([])
        plt.yticks([])
        # The figure is handed to Streamlit directly; plt.show() is not needed.
        st.pyplot(fig)

    # Placeholder so the button can be removed once prediction has started.
    holder = st.empty()
    if holder.button("Start Prediction"):
        start = time.time()
        holder.empty()
        with st.spinner("Processing, please wait around 7-15 minute"):
            data_t1 = mat['data_t1']
            data_t2 = mat['data_t2']

            # Rescale the label maps: 0 becomes -0.8 and 1 becomes -0.2.
            L_post = np.double(mat['L_post'])
            L_post[L_post == 0] = -0.8
            L_post[L_post == 1] = -0.2
            L_pre = np.double(mat['L_pre'])
            L_pre[L_pre == 0] = -0.8
            L_pre[L_pre == 1] = -0.2

            # Crop the imagery to the label extent, then stack the first ten
            # channels with the rescaled label map as an eleventh channel.
            data_t1 = data_t1[:L_post.shape[0], :L_post.shape[1], :]
            data_t2 = data_t2[:L_post.shape[0], :L_post.shape[1], :]
            data_cb1 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
            data_cb2 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
            data_cb1[:, :, :10] = data_t1
            data_cb1[:, :, 10] = L_pre
            data_cb2[:, :, :10] = data_t2
            data_cb2[:, :, 10] = L_post

            # A 5x5 window sliding with step 1 yields (H-4) x (W-4) patches.
            height, width, band = data_cb1.shape
            height = height - 4
            width = width - 4
            x1 = patchify(data_cb1, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)
            x2 = patchify(data_cb2, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)

            # create model
            model = SSTViT(
                image_size=5,
                near_band=args.band_patches,
                num_patches=11,
                num_classes=3,
                dim=32,
                depth=2,
                heads=4,
                dim_head=16,
                mlp_dim=8,
                b_dim=512,
                b_depth=3,
                b_heads=8,
                b_dim_head=32,
                b_mlp_head=8,
                dropout=0.2,
                emb_dropout=0.1,
            )
            model.load_state_dict(torch.load("model/lsstformer.pth", map_location=torch.device("cpu")))
            model.eval()

            # Patches rearranged to (patch, channel, pixel) float tensors.
            x1_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
            x2_true_band = torch.from_numpy(x2.transpose(0, 2, 1)).type(torch.FloatTensor)

            # Inference only, so no autograd graph is needed.
            with torch.no_grad():
                # output classification maps
                # NOTE(review): the "Before" map feeds the x1 patches as BOTH
                # model inputs, exactly as the original code did. Presumably
                # intentional (same date twice) - confirm it is not a typo
                # for x2.
                Label_true = Data.TensorDataset(x1_true_band, x1_true_band)
                label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
                pre_u = test_epoch(model, label_true_loader)
                prediction_matrix = pre_u.reshape(height, width)

                # "After" map: first-date patches paired with second-date ones.
                Label_true = Data.TensorDataset(x1_true_band, x2_true_band)
                label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
                pre_u = test_epoch(model, label_true_loader)
                prediction_matrix2 = pre_u.reshape(height, width)

            # Pixel totals behind the summary table: class 2 in each map and
            # class 1 in the "After" map.
            mangrove_before_px = int(np.count_nonzero(prediction_matrix == 2))
            mangrove_after_px = int(np.count_nonzero(prediction_matrix2 == 2))
            mangrove_loss_px = int(np.count_nonzero(prediction_matrix2 == 1))

            class_counts = count_pixel(encode_masks_to_rgb(prediction_matrix))
            class_counts1 = count_pixel1(encode_masks_to_rgb(prediction_matrix2))

            # Grid height grows with the row count up to a fixed cap.
            MIN_HEIGHT = 100
            MAX_HEIGHT = 180
            ROW_HEIGHT = 50

            with st.container():
                st.subheader("Prediction Result")
                col1, col2 = st.columns(2)
                with col1:
                    with st.container():
                        fig = plt.figure(figsize=(10, 10))
                        plt.subplot(121)
                        plt.imshow(prediction_matrix, cmap=cmap)
                        plt.title('Before', fontsize=25, fontweight='bold')
                        plt.xticks([])
                        plt.yticks([])
                        plt.subplot(122)
                        plt.imshow(prediction_matrix2, cmap=cmap)
                        plt.title('After', fontsize=25, fontweight='bold')
                        plt.xticks([])
                        plt.yticks([])
                        st.pyplot(fig)
                with col2:
                    with st.container():
                        # Each pixel is reported as 100 m² (presumably the
                        # 10 m Sentinel-2 ground resolution - confirm).
                        table_data = {
                            "Total mangrove before": f"{mangrove_before_px*100} m\u00B2",
                            "Total mangrove after": f"{mangrove_after_px*100} m\u00B2",
                            "Total mangrove loss": f"{mangrove_loss_px*100} m\u00B2",
                        }
                        df = pd.DataFrame(list(table_data.items()), columns=['Key', 'Value'])
                        # st.dataframe(df, hide_index=True, use_container_width=True)
                        st_aggrid.AgGrid(df, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(df) * ROW_HEIGHT, MAX_HEIGHT))

            with st.container():
                st.subheader("Pixel Distribution")
                # Keep "Mangrove Before" (row 2 of the first table) plus
                # "Mangrove Loss" and "Mangrove After" (rows 1 and 2 of the
                # second), then order them Before, After, Loss.
                vertical_concat = pd.concat(
                    [class_counts.drop([0, 1]), class_counts1.drop(0)], axis=0
                )
                vertical_concat = vertical_concat.iloc[[0, 2, 1], :]
                st_aggrid.AgGrid(vertical_concat, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(vertical_concat) * ROW_HEIGHT, MAX_HEIGHT))
                fig = px.bar(vertical_concat, x='percentage', y='class', color='class', orientation='h',
                             color_discrete_sequence=["green", "green", "red", "blue"],
                             category_orders={"class": ["Mangrove Before", "Mangrove After", "Mangrove Loss", "Non Mangrove",]}
                             )
                st.plotly_chart(fig, use_container_width=False)

        # Wall-clock duration of the whole prediction branch, in seconds.
        end = time.time()
        process = end - start
        st.write('process', process)

show_file = st.empty()
if not file:
    # Nothing uploaded yet: explain the expected input and link a demo file.
    url = "https://drive.usercontent.google.com/download?id=1u48pMzRWQ2Etfjaq5A0CUjRtGKZaJoJy&export=download&authuser=2&confirm=t&uuid=52b0e01e-377f-42cb-8412-c84aa38a1740&at=APZUnTXslmuCCV1drJ2WWtkZr9BR%3A1710357675310"
    show_file.info("""
The model was trained using Sentinel-2 imagery, users can upload MAT files to perform LSST-Former for mangrove loss detection models that have been trained in this research. Tool for generate from Sentinel-2 to MAT file i will create later, please download demo dataset bellow. for better in mobile phone, use desktop mode.
""")
    st.write("download demo datasets this [link](%s)" % url)