Manthanx committed
Commit 3ede0f3
1 Parent(s): 55a2d33

Upload 4 files

Files changed (4)
  1. app.py +68 -0
  2. cityscapes_dataUNet.pth +3 -0
  3. requirements.txt +6 -0
  4. unet.py +57 -0
app.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ import streamlit as st
+ import numpy as np
+ from PIL import Image
+ from unet import UNet
+ from torchvision import transforms
+ from torchvision.transforms.functional import to_tensor, to_pil_image
+ import matplotlib.pyplot as plt
+ from torch.utils.data import Dataset, DataLoader
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ # Load the trained model
+ model_path = 'cityscapes_dataUNet.pth'
+ num_classes = 10
+ model = UNet(num_classes=num_classes)
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.to(device)
+ model.eval()
+
+ # Define the prediction function that takes an input image and returns the predicted class mask
+ def predict_segmentation(image):
+     st.write(device)
+     # Convert the input image to a PyTorch tensor and normalize it
+     image = Image.fromarray(image).convert('RGB')
+     image = transforms.functional.resize(image, (256, 256))
+     image = to_tensor(image).unsqueeze(0)
+     image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
+     image = image.to(device)
+
+     st.write("Input shape:", image.shape)  # torch.Size([1, 3, 256, 256])
+     st.write("Input dtype:", image.dtype)  # torch.float32
+
+     # Make a prediction using the model
+     with torch.no_grad():
+         st.write(image.shape, image.dtype)  # torch.Size([1, 3, 256, 256]) torch.float32
+
+         output = model(image)
+     predicted_class = torch.argmax(output, dim=1).squeeze(0)
+     predicted_class = predicted_class.cpu().detach().numpy().astype(np.uint8)
+     st.write("Predicted class dtype:", predicted_class.dtype)
+     st.write("Predicted class shape:", predicted_class.shape)
+
+     # Visualize the predicted segmentation mask
+     plt.imshow(predicted_class)
+     st.pyplot(plt)
+
+     st.write("Predicted class:", predicted_class)
+
+     # Return the predicted segmentation (per-pixel class indices)
+     return predicted_class
+
+ # Define the Streamlit interface
+ st.title('UNet Image Segmentation')
+ st.write('Segment an image using a UNet model')
+
+ uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+
+ if uploaded_image is not None:
+     # Read the uploaded image
+     image = Image.open(uploaded_image)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+
+     # Process the image and get the segmented result
+     segmented_image = predict_segmentation(np.array(image))
+
+     # Display the segmented image (raw class indices; see the colorization sketch below)
+     st.image(segmented_image, caption='Segmented Image', use_column_width=True)
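
The mask returned by predict_segmentation holds raw class indices (0-9 here), so passing it straight to st.image renders an almost-black image. A minimal sketch of one way to colorize it for display; the helper name and the tab10 colormap are assumptions, not part of the committed files:

import numpy as np
from matplotlib import cm

def colorize_mask(mask, num_classes=10):
    # Hypothetical helper: map an (H, W) class-index mask to an (H, W, 3) uint8 RGB image.
    cmap = cm.get_cmap('tab10', num_classes)       # one colour per class
    rgba = cmap(mask % num_classes)                # integer input indexes the colormap lookup table
    return (rgba[..., :3] * 255).astype(np.uint8)

# Usage inside the app could then look like:
# st.image(colorize_mask(segmented_image), caption='Segmented Image', use_column_width=True)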
cityscapes_dataUNet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f41259bc794efd66defb8c0029ea054dcf4bb0b98dcc3229e36267782132315
+ size 138216745
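
The checkpoint is stored via Git LFS (the pointer above), so a clone made without git-lfs leaves only this small text stub on disk and torch.load in app.py then fails. A small sanity check one could add; the helper is an assumption, not in the commit:

import os

def is_lfs_pointer(path):
    # Hypothetical check: the real checkpoint is ~138 MB, while an LFS pointer stub starts with "version".
    if os.path.getsize(path) > 1024:
        return False
    with open(path, 'rb') as f:
        return f.read(7) == b'version'

if is_lfs_pointer('cityscapes_dataUNet.pth'):
    raise RuntimeError('Model weights not fetched; run `git lfs pull` first.')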
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==3.33.1
+ matplotlib==3.6.2
+ numpy==1.24.2
+ Pillow==9.3.0
+ torch==2.0.0
+ torchvision==0.15.1
unet.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ class UNet(nn.Module):
+
+     def __init__(self, num_classes):
+         super(UNet, self).__init__()
+         self.num_classes = num_classes
+         self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
+         self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
+         self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
+         self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
+         self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.middle = self.conv_block(in_channels=512, out_channels=1024)
+         self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
+         self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
+         self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
+         self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
+         self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1)
+
+     def conv_block(self, in_channels, out_channels):
+         block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
+                               nn.ReLU(),
+                               nn.BatchNorm2d(num_features=out_channels),
+                               nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
+                               nn.ReLU(),
+                               nn.BatchNorm2d(num_features=out_channels))
+         return block
+
+     def forward(self, X):
+         contracting_11_out = self.contracting_11(X)  # [-1, 64, 256, 256]
+         contracting_12_out = self.contracting_12(contracting_11_out)  # [-1, 64, 128, 128]
+         contracting_21_out = self.contracting_21(contracting_12_out)  # [-1, 128, 128, 128]
+         contracting_22_out = self.contracting_22(contracting_21_out)  # [-1, 128, 64, 64]
+         contracting_31_out = self.contracting_31(contracting_22_out)  # [-1, 256, 64, 64]
+         contracting_32_out = self.contracting_32(contracting_31_out)  # [-1, 256, 32, 32]
+         contracting_41_out = self.contracting_41(contracting_32_out)  # [-1, 512, 32, 32]
+         contracting_42_out = self.contracting_42(contracting_41_out)  # [-1, 512, 16, 16]
+         middle_out = self.middle(contracting_42_out)  # [-1, 1024, 16, 16]
+         expansive_11_out = self.expansive_11(middle_out)  # [-1, 512, 32, 32]
+         expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1))  # [-1, 1024, 32, 32] -> [-1, 512, 32, 32]
+         expansive_21_out = self.expansive_21(expansive_12_out)  # [-1, 256, 64, 64]
+         expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1))  # [-1, 512, 64, 64] -> [-1, 256, 64, 64]
+         expansive_31_out = self.expansive_31(expansive_22_out)  # [-1, 128, 128, 128]
+         expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1))  # [-1, 256, 128, 128] -> [-1, 128, 128, 128]
+         expansive_41_out = self.expansive_41(expansive_32_out)  # [-1, 64, 256, 256]
+         expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1))  # [-1, 128, 256, 256] -> [-1, 64, 256, 256]
+         output_out = self.output(expansive_42_out)  # [-1, num_classes, 256, 256]
+         return output_out
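
A quick smoke test for the module above, a sketch rather than part of the committed files. Note that the four 2x2 max-pools mean the input height and width must be divisible by 16, otherwise the torch.cat skip connections in forward() fail with a size mismatch:

import torch
from unet import UNet

model = UNet(num_classes=10)
x = torch.randn(1, 3, 256, 256)  # [batch, channels, H, W], H and W divisible by 16
with torch.no_grad():
    y = model(x)
print(y.shape)                   # expected: torch.Size([1, 10, 256, 256])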