svsaurav95 committed
Commit 061386e · verified · 1 Parent(s): 54dc4e9

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +91 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,93 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
 
 
 
  import streamlit as st
+ import torch
+ import torch.nn as nn
+ import timm
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import io
+
+ # Hide Streamlit warnings and UI elements
+ st.set_page_config(layout="wide")
+ st.markdown("""
+ <style>
+ footer {visibility: hidden;}
+ </style>
+ """, unsafe_allow_html=True)
+
+ # === Model Definition ===
+ class MobileViTSegmentation(nn.Module):
+     def __init__(self, encoder_name='mobilevit_s', pretrained=False):
+         super().__init__()
+         self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
+         self.encoder_channels = self.backbone.feature_info.channels()
+
+         self.decoder = nn.Sequential(
+             nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
+             nn.Upsample(scale_factor=2, mode='bilinear'),
+             nn.Conv2d(128, 64, kernel_size=3, padding=1),
+             nn.Upsample(scale_factor=2, mode='bilinear'),
+             nn.Conv2d(64, 32, kernel_size=3, padding=1),
+             nn.Upsample(scale_factor=2, mode='bilinear'),
+             nn.Conv2d(32, 1, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         feats = self.backbone(x)
+         out = self.decoder(feats[-1])
+         out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
+         return out
+
+ # === Load Model ===
+ @st.cache_resource
+ def load_model():
+     model = MobileViTSegmentation()
+     state_dict = torch.load("model/mobilevit_teeth_segmentation.pth", map_location="cpu")
+     model.load_state_dict(state_dict)
+     model.eval()
+     return model
+
+ model = load_model()
+
+ # === Preprocessing ===
+ def preprocess_image(image: Image.Image):
+     image = image.convert("RGB").resize((256, 256))
+     arr = np.array(image).astype(np.float32) / 255.0
+     arr = np.transpose(arr, (2, 0, 1))  # HWC → CHW
+     tensor = torch.tensor(arr).unsqueeze(0)  # Add batch dim
+     return tensor
+
+ # === Postprocessing: Overlay Mask ===
+ def overlay_mask(image_pil, mask_tensor, threshold=0.7):
+     image = np.array(image_pil.resize((256, 256)))
+     mask = mask_tensor.squeeze().detach().numpy()
+     mask_bin = (mask > threshold).astype(np.uint8) * 255
+
+     mask_color = np.zeros_like(image)
+     mask_color[..., 2] = mask_bin  # Blue mask
+
+     overlayed = cv2.addWeighted(image, 1.0, mask_color, 0.5, 0)
+     return overlayed
+
+ # === UI ===
+ st.title("🦷 Tooth Segmentation with MobileViT")
+ st.write("Upload an image to segment the **visible teeth area** using a lightweight MobileViT segmentation model.")
+
+ uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     image = Image.open(uploaded_file)
+     tensor = preprocess_image(image)
+
+     with st.spinner("Segmenting..."):
+         with torch.no_grad():
+             pred = model(tensor)[0]
+
+     overlayed_img = overlay_mask(image, pred)

+     col1, col2 = st.columns(2)
+     with col1:
+         st.image(image, caption="Original Image", use_container_width=True)
+     with col2:
+         st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
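
A quick way to sanity-check the decoder geometry added in this commit, without the checkpoint or the Streamlit UI, is sketched below. The sketch is not part of the commit and assumes only that torch and timm are installed: mobilevit_s with features_only=True reports its deepest feature map at stride 32, so a 256x256 input yields roughly an 8x8 map, the decoder's three 2x upsamples bring it to about 64x64, and the final interpolate in forward() restores the full 256x256 resolution.

# Standalone sketch (not part of this commit): check the backbone feature shapes
# that MobileViTSegmentation's decoder relies on.
import timm
import torch

backbone = timm.create_model("mobilevit_s", features_only=True, pretrained=False)
print(backbone.feature_info.channels())   # per-stage channel counts; the last entry feeds the decoder

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 256))

# The deepest feature map should be at stride 32, i.e. (1, C, 8, 8) for a 256x256 input.
print(feats[-1].shape)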