import torch
import streamlit as st
import numpy as np
from PIL import Image
from unet import UNet
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

# Load the trained model.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the checkpoint is reloaded on each rerun. Caching the model
# (e.g. st.cache_resource, on Streamlit versions that provide it) would avoid it.
model_path = 'cityscapes_dataUNet.pth'
num_classes = 10
model = UNet(num_classes=num_classes)
# Map the checkpoint tensors straight onto the target device.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# ImageNet channel statistics, built once instead of on every prediction.
# Presumably these match the normalization used during training — TODO confirm.
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))


def predict_segmentation(image):
    """Run the UNet on an image, plot the predicted mask and return it.

    Args:
        image: the input image as a NumPy array. An H x W x 3 uint8 RGB array
            is expected; greyscale or RGBA arrays are converted to RGB first.

    Returns:
        np.ndarray of dtype uint8 holding the argmax class index of each
        pixel (values < num_classes, so uint8 is wide enough). Its spatial
        size is whatever the UNet outputs — presumably the input H x W;
        confirm against the UNet implementation.
    """
    st.write(device)

    # Let PIL infer the mode from the array, then force 3 channels. Passing
    # mode='RGB' directly misreads RGBA (4-channel) or greyscale arrays.
    pil_image = Image.fromarray(image).convert('RGB')
    # NOTE(review): no resize is applied, so the model sees the upload's
    # native resolution. If the UNet needs sizes divisible by its total
    # downsampling factor (or was trained at 256x256), resize here.
    tensor = normalize(to_tensor(pil_image).unsqueeze(0)).to(device)

    st.write("Input shape:", tensor.shape)
    st.write("Input dtype:", tensor.dtype)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(tensor)
        # Collapse the class dimension (dim 1) to one label per pixel and drop
        # the batch dimension.
        predicted_class = torch.argmax(output, dim=1).squeeze(0)

    predicted_class = predicted_class.cpu().numpy().astype(np.uint8)
    st.write("Predicted class dtype:", predicted_class.dtype)
    st.write("Predicted class shape:", predicted_class.shape)

    # Draw on an explicit figure and close it afterwards. Reusing pyplot's
    # implicit global figure would leak it and carry state across reruns.
    fig, ax = plt.subplots()
    ax.imshow(predicted_class)
    st.pyplot(fig)
    plt.close(fig)

    st.write("Predicted class:", predicted_class)
    return predicted_class


# Streamlit interface
st.title('UNet Image Segmentation IPPR')
st.write('Segment an image using a UNet model')
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Convert to RGB before taking the array. Palette ('P') PNGs would
    # otherwise arrive as raw palette indices, and RGBA/greyscale images
    # would have the wrong channel count.
    segmented_image = predict_segmentation(np.array(image.convert('RGB')))
    # The mask holds raw class indices (0..num_classes-1), which would show
    # as a near-black image via st.image. predict_segmentation already
    # plots it with matplotlib instead.