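# IPPR_IMAGE_SEG / app.py
# Author: Manthanx — commit 3ede0f3 ("Upload 4 files"), 2.51 kB.
# (Header recovered from a file-viewer page; its "raw" / "history" / "blame"
# links and avatar caption were page chrome, not part of the program.)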
import torch
import streamlit as st
import numpy as np
from PIL import Image
from unet import UNet
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the trained model
model_path = 'cityscapes_dataUNet.pth'
# Must match the output channels of the saved checkpoint; a mismatch makes
# load_state_dict raise on the final layer's weight shapes.
num_classes = 10
model = UNet(num_classes=num_classes)
# map_location remaps every stored tensor onto `device`. Without it, a
# checkpoint saved from CUDA cannot be loaded on a CPU-only machine.
# NOTE(review): Streamlit re-executes this script on each interaction, so the
# weights are reloaded every rerun; consider wrapping this in st.cache_resource
# if the installed Streamlit version provides it.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Inference mode: disables dropout and freezes batch-norm statistics, if the
# UNet contains such layers.
model.eval()
# Define the prediction function that takes an input image and returns the segmented image
def predict_segmentation(image):
    """Run the loaded UNet on one image and return its per-pixel class map.

    Besides returning the mask, this writes debug information (device, input
    tensor shape/dtype, the predicted array) and a matplotlib rendering of the
    mask to the Streamlit page.

    Args:
        image: the picture as a NumPy array, e.g. ``np.array(pil_image)``.
            Grayscale (H, W), RGB (H, W, 3) and RGBA (H, W, 4) arrays are
            accepted; Pillow infers the mode and the result is converted to
            3-channel RGB before inference.

    Returns:
        np.ndarray of dtype uint8 with the argmax class index of each pixel.
    """
    st.write(device)
    # Let Pillow infer the mode from the array, then convert. Forcing mode
    # 'RGB' misreads grayscale arrays and RGBA arrays (e.g. PNGs with alpha).
    pil_image = Image.fromarray(np.asarray(image)).convert('RGB')
    # NOTE(review): no resize is applied, so the network receives the upload's
    # native resolution. If this UNet requires H/W divisible by a power of two,
    # arbitrary uploads may fail — confirm against unet.UNet.
    tensor = to_tensor(pil_image).unsqueeze(0)  # (1, 3, H, W), values in [0, 1]
    # Standard ImageNet channel mean/std; presumably the statistics used when
    # training this checkpoint — verify.
    tensor = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(tensor)
    tensor = tensor.to(device)
    st.write("Input shape:", tensor.shape)
    st.write("Input dtype:", tensor.dtype)
    # Make a prediction using the model
    with torch.no_grad():
        st.write(tensor.shape, tensor.dtype)
        output = model(tensor)
        # Reduce over dim 1 — assumes output is (1, num_classes, H, W) class
        # scores; TODO confirm. squeeze(0) drops the batch axis.
        predicted_class = torch.argmax(output, dim=1).squeeze(0)
    # argmax output carries no autograd graph, so no detach() is needed.
    predicted_class = predicted_class.cpu().numpy().astype(np.uint8)
    st.write("Predicted class dtype:", predicted_class.dtype)
    st.write("Predicted class shape:", predicted_class.shape)
    # Visualize the predicted segmentation mask on an explicit figure rather
    # than pyplot's implicit global one (passing the pyplot module to
    # st.pyplot is deprecated), and close it so reruns don't pile up figures.
    fig, ax = plt.subplots()
    ax.imshow(predicted_class)
    st.pyplot(fig)
    plt.close(fig)
    st.write("Predicted class:", predicted_class)
    # Return the predicted segmentation
    return predicted_class
# Define the Streamlit interface
st.title('UNet Image Segmentation')
st.write('Segment an image using a UNet model')
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_image is not None:
    # Read the uploaded image and normalise palette / grayscale / RGBA files
    # to 3-channel RGB, which is what the model input pipeline expects.
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Process the image and get the segmented result
    segmented_image = predict_segmentation(np.array(image))
    # The mask holds raw class indices (0..num_classes-1). Shown directly as
    # uint8 they render almost black, so rescale to floats in [0, 1] (a range
    # st.image accepts) to make the classes visually distinguishable.
    display_mask = np.clip(
        segmented_image.astype(np.float32) / max(num_classes - 1, 1), 0.0, 1.0
    )
    # Display the segmented image
    st.image(display_mask, caption='Segmented Image', use_column_width=True)