# ToothSegmentation / src/streamlit_app.py
# Last commit: "Update src/streamlit_app.py" by svsaurav95 (744e6f4, verified)
# File size: 2.9 kB
import streamlit as st
import torch
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torchvision.transforms as T
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# ========== Model Definition ==========
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation net: timm feature backbone + small conv decoder.

    Only the deepest feature map produced by the backbone is decoded. The
    decoder finishes with a sigmoid, and the result is bilinearly resized to
    the height and width of the input tensor.

    Args:
        encoder_name: Model name forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name,
            features_only=True,
            pretrained=pretrained,
        )
        self.encoder_channels = self.backbone.feature_info.channels()

        deepest_channels = self.encoder_channels[-1]
        # The position of each layer in this list becomes its index inside
        # nn.Sequential, and therefore part of the state-dict key
        # (decoder.0, decoder.2, ...). Keep the order stable so saved
        # checkpoints continue to load.
        decoder_layers = [
            nn.Conv2d(deepest_channels, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the decoded map resized to the spatial size of ``x``.

        ``x`` is indexed as ``x.shape[2]`` / ``x.shape[3]`` for height and
        width, so a 4-D input is expected.
        """
        deepest_features = self.backbone(x)[-1]
        coarse_map = self.decoder(deepest_features)
        height, width = x.shape[2], x.shape[3]
        return nn.functional.interpolate(
            coarse_map,
            size=(height, width),
            mode='bilinear',
            align_corners=False,
        )
# ========== Load Model ==========
@st.cache_resource
def load_model():
    """Download the segmentation checkpoint and return the model in eval mode.

    The checkpoint file ``mobilevit_teeth_segmentation.pth`` is fetched from
    the Hugging Face Hub repo ``svsaurav95/ToothSegmentation`` and loaded onto
    the CPU. The function is wrapped in ``st.cache_resource``.

    Returns:
        MobileViTSegmentation: model with the checkpoint weights applied,
        switched to evaluation mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="svsaurav95/ToothSegmentation",
        filename="mobilevit_teeth_segmentation.pth",
    )
    # SECURITY: torch.load unpickles the file, and this file comes from the
    # network. The checkpoint is consumed directly by load_state_dict, so it
    # only needs tensors and plain containers; weights_only=True makes the
    # unpickler refuse anything else (requires torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

    # pretrained=False: the backbone weights come from the checkpoint above.
    model = MobileViTSegmentation(pretrained=False)
    model.load_state_dict(state_dict)
    model.eval()
    return model
model = load_model()

# ========== Image Transformation ==========
# Preprocessing for every uploaded image: resize to 256x256, then convert to a tensor.
transform = T.Compose([T.Resize(size=(256, 256)), T.ToTensor()])
# ========== Streamlit UI ==========
# Flow: upload an image -> run the model -> threshold the prediction ->
# tint the predicted region blue -> show original and result side by side.
st.title("Tooth Segmentation using MobileViT")
uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    input_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        # [0, 0] selects the first batch item and its single output channel.
        pred_mask = model(input_tensor)[0, 0].numpy()

    # Post-processing: binarize at 0.7, then scale the mask back to the
    # size of the uploaded image.
    pred_mask = (pred_mask > 0.7).astype(np.uint8) * 255
    pred_mask = Image.fromarray(pred_mask).resize(image.size)

    # Create overlay
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 255, 100))  # Blue translucent
    # Blend the translucent blue onto the photo first. Image.composite only
    # *selects* pixels by mask; it does not apply the overlay's alpha. Passing
    # the raw overlay would leave alpha=100 pixels in the result, so the masked
    # area would show the page background instead of a tinted photo.
    tinted = Image.alpha_composite(base, overlay)
    # pred_mask is already an "L" image of the right size, so it can be used
    # as the composite mask directly.
    final = Image.composite(tinted, base, pred_mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Segmentation Overlay", use_container_width=True)