# Source: IPPR_IMAGE_SEG / app.py (commit 4f9dc10, "Update app.py", by Manthanx; 2.55 kB).
import torch
import streamlit as st
import numpy as np
from PIL import Image
from unet import UNet
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
# Select the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network and restore its trained weights. The checkpoint is
# deserialized onto the CPU first; the model is moved to `device` afterwards.
model_path = 'cityscapes_dataUNet.pth'
num_classes = 10
model = UNet(num_classes=num_classes)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)
model = model.to(device)
# Inference only: put the model in evaluation mode.
model.eval()
# Define the prediction function that takes an input image and returns the segmented image
def predict_segmentation(image):
    """Segment an image with the module-level UNet and render the mask.

    The function writes diagnostic information (device, tensor shapes and
    dtypes, the raw class map) and a matplotlib rendering of the predicted
    mask to the Streamlit page as a side effect.

    Args:
        image: Image pixels as a NumPy array, e.g. ``np.array(pil_image)``.
            Any layout Pillow can infer (greyscale, RGB, RGBA, ...) is
            accepted; it is converted to 3-channel RGB before inference.

    Returns:
        ``np.uint8`` array of per-pixel class indices, obtained by taking the
        argmax over dimension 1 of the model output and dropping the batch
        dimension (presumably shaped ``(H, W)`` -- depends on the UNet
        output layout, which is defined outside this file).
    """
    st.write(device)

    # Let Pillow infer the mode from the array and then convert to RGB.
    # Forcing mode 'RGB' in fromarray() does not convert: an RGBA buffer
    # would be read as packed RGB triplets and a 2-D greyscale buffer would
    # be too short for the declared mode.
    image = Image.fromarray(image).convert('RGB')

    # NOTE(review): no resizing is applied, so the network sees the upload at
    # its native resolution. The shape comments below mention 256x256 inputs;
    # confirm that the UNet accepts arbitrary spatial sizes.
    # image = transforms.functional.resize(image, (256, 256))

    # Convert to a float tensor, add the batch dimension and normalize per
    # channel (these are the commonly used ImageNet statistics; presumably
    # the same ones the model was trained with).
    image = to_tensor(image).unsqueeze(0)
    image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
    image = image.to(device)
    st.write("Input shape:", image.shape)  # e.g. torch.Size([1, 3, 256, 256])
    st.write("Input dtype:", image.dtype)  # e.g. torch.float32

    # Make a prediction using the model; gradients are not needed.
    with torch.no_grad():
        st.write(image.shape, image.dtype)
        output = model(image)
        # Highest-scoring class per pixel, with the batch dimension removed.
        predicted_class = torch.argmax(output, dim=1).squeeze(0)

    predicted_class = predicted_class.cpu().numpy().astype(np.uint8)
    st.write("Predicted class dtype:", predicted_class.dtype)
    st.write("Predicted class shape:", predicted_class.shape)

    # Visualize the predicted segmentation mask on its own figure. Using an
    # explicit figure (rather than the implicit global pyplot figure) keeps
    # reruns from drawing on top of earlier output, and closing it releases
    # the figure once Streamlit has rendered it.
    fig, ax = plt.subplots()
    ax.imshow(predicted_class)
    st.pyplot(fig)
    plt.close(fig)

    st.write("Predicted class:", predicted_class)

    # Return the predicted segmentation
    return predicted_class
# Define the Streamlit interface
st.title('UNet Image Segmentation IPPR')
st.write('Segment an image using a UNet model')
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_image is not None:
    # Read the uploaded image and show it exactly as uploaded.
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # PNG uploads may be RGBA, palette or greyscale. predict_segmentation
    # builds an RGB image from the array it receives, so hand it a true
    # 3-channel RGB array regardless of the upload's native mode.
    rgb_pixels = np.array(image.convert('RGB'))

    # Process the image and get the segmented result. predict_segmentation
    # renders the mask to the page itself via st.pyplot, so the returned
    # class map is not displayed again here.
    segmented_image = predict_segmentation(rgb_pixels)
    # st.image(segmented_image, caption='Segmented Image', use_column_width=True)