# Spaces: Sleeping
import torch | |
import streamlit as st | |
import numpy as np | |
from PIL import Image | |
from unet import UNet | |
from torchvision import transforms | |
from torchvision.transforms.functional import to_tensor, to_pil_image | |
import matplotlib.pyplot as plt | |
from torch.utils.data import Dataset, DataLoader | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Restore the trained UNet from its checkpoint file.
model_path = 'cityscapes_dataUNet.pth'
num_classes = 10
model = UNet(num_classes=num_classes)
# The weights are deserialized onto the CPU first so the checkpoint loads even
# on a machine without CUDA; the module is moved to `device` right after.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
# Inference only: switch the module to evaluation mode.
model.eval()
# Define the prediction function that takes an input image and returns the segmented image | |
def predict_segmentation(image):
    """Segment an image with the module-level UNet and return the class mask.

    Args:
        image: Pixel data as a NumPy array, e.g. the result of ``np.array``
            on a PIL image. Arrays that are not already 3-channel RGB
            (grayscale, RGBA, ...) are converted to RGB before inference.

    Returns:
        A ``np.uint8`` array with the arg-max class index per pixel, with the
        batch dimension removed (presumably shape ``(H, W)`` -- this depends
        on the UNet output layout, which is defined outside this file).

    Side effects:
        Writes debug information (device, tensor shape/dtype, the raw mask)
        and a matplotlib rendering of the mask to the Streamlit page.
    """
    st.write(device)

    # Let Pillow infer the mode from the array, then convert to RGB. Forcing
    # mode 'RGB' on a grayscale or RGBA array misinterprets the pixel buffer
    # (garbled colours or a "not enough image data" error).
    image = Image.fromarray(image).convert('RGB')
    # NOTE(review): resizing to (256, 256) is disabled, so the network gets
    # the image at its native size -- confirm the UNet accepts arbitrary sizes.
    # image = transforms.functional.resize(image, (256, 256))

    # HWC uint8 -> 1xCxHxW float tensor, normalized with the mean/std below.
    image = to_tensor(image).unsqueeze(0)
    image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
    image = image.to(device)
    st.write("Input shape:", image.shape)  # e.g. torch.Size([1, 3, 256, 256])
    st.write("Input dtype:", image.dtype)  # e.g. torch.float32

    # Forward pass without gradient tracking.
    with torch.no_grad():
        st.write(image.shape, image.dtype)
        output = model(image)
        # Pick the highest-scoring class along dim 1 and drop the batch dim.
        predicted_class = torch.argmax(output, dim=1).squeeze(0)

    # No detach() needed: the tensor was produced under no_grad.
    predicted_class = predicted_class.cpu().numpy().astype(np.uint8)
    st.write("Predicted class dtype:", predicted_class.dtype)
    st.write("Predicted class shape:", predicted_class.shape)

    # Render the mask on an explicit figure instead of the implicit global
    # pyplot figure, and close it afterwards so figures do not pile up across
    # Streamlit reruns.
    fig, ax = plt.subplots()
    ax.imshow(predicted_class)
    st.pyplot(fig)
    plt.close(fig)

    st.write("Predicted class:", predicted_class)

    return predicted_class
# Define the Streamlit interface | |
# Page header.
st.title('UNet Image Segmentation IPPR')
st.write('Segment an image using a UNet model')

# File picker restricted to PNG/JPEG uploads; yields None until a file is chosen.
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_image is not None:
    # Decode the upload and echo it back to the user.
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Hand the pixels to the model as a NumPy array; predict_segmentation
    # renders the mask itself, so the result is only kept here.
    pixel_array = np.array(image)
    segmented_image = predict_segmentation(pixel_array)

    # Direct display of the returned mask is currently disabled:
    # st.image(segmented_image, caption='Segmented Image', use_column_width=True)