import streamlit as st
import torch
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import torchvision.transforms as T
from huggingface_hub import hf_hub_download

# ========== Model Definition ==========
class MobileViTSegmentation(nn.Module):
    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        # MobileViT encoder from timm; features_only=True returns the
        # intermediate feature maps instead of classification logits.
        self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
        self.encoder_channels = self.backbone.feature_info.channels()

        # Lightweight decoder: three conv + 2x bilinear upsample stages on the
        # deepest feature map, then a 1x1 conv + sigmoid for a one-channel mask.
        self.decoder = nn.Sequential(
            nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feats = self.backbone(x)
        # Decode the deepest feature map, then interpolate the probability map
        # back to the input resolution (the decoder's 8x upsampling alone does
        # not fully undo the encoder's downsampling).
        out = self.decoder(feats[-1])
        out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
        return out
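
# Quick shape sanity check (a sketch, meant to be run in a separate script or
# REPL rather than inside the app; assumes input sizes divisible by 32):
#   m = MobileViTSegmentation(pretrained=False)
#   x = torch.randn(1, 3, 256, 256)
#   print(m(x).shape)  # expected: torch.Size([1, 1, 256, 256])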

# ========== Load Model ==========
@st.cache_resource
def load_model():
    # Download the checkpoint from the Hugging Face Hub (cached on disk) and
    # keep the loaded model in Streamlit's resource cache across reruns.
    checkpoint_path = hf_hub_download(repo_id="svsaurav95/ToothSegmentation", filename="mobilevit_teeth_segmentation.pth")
    model = MobileViTSegmentation(pretrained=False)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

model = load_model()

# ========== Image Transformation ==========
# Resize to 256x256 and convert to a [0, 1] float tensor (no normalization)
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

# ========== Streamlit UI ==========
st.title("Tooth Segmentation using MobileViT")

uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    input_tensor = transform(image).unsqueeze(0)  # add batch dimension

    # Inference without gradient tracking; take the single-channel mask
    with torch.no_grad():
        pred_mask = model(input_tensor)[0, 0].numpy()

    # Post-processing: threshold the probability map into a binary mask and
    # resize it back to the original image size (NEAREST keeps the edges hard;
    # the default bicubic filter would blur the mask into gray fringes)
    pred_mask = (pred_mask > 0.7).astype(np.uint8) * 255
    pred_mask = Image.fromarray(pred_mask).resize(image.size, Image.NEAREST)

    # Overlay: composite a translucent blue layer onto the photo wherever the
    # mask is set (the "L"-mode mask can be used directly as the composite mask)
    overlay = Image.new("RGBA", image.size, (0, 0, 255, 100))
    base = image.convert("RGBA")
    final = Image.composite(overlay, base, pred_mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Segmentation Overlay", use_container_width=True)
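
# To run the app locally (a sketch; assumes these dependencies suffice):
#   pip install streamlit torch torchvision timm pillow numpy huggingface_hub
#   streamlit run app.py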