import os
import sys
import shutil
import importlib.util
from io import BytesIO
from ultralytics import YOLO
from PIL import Image
import torch
# ─── FORCE CPU ONLY ─────────────────────────────────────────────────────────
def _stay_put(self, *args, **kwargs):
    """Replacement for ``.cuda()``: return the receiver unchanged."""
    return self


def _sync_noop(*args, **kwargs):
    """Replacement for ``torch.cuda.synchronize``: accept anything, do nothing."""
    return None


def _cuda_unavailable():
    """Replacement for ``torch.cuda.is_available``: always report False."""
    return False


def _zero_cuda_devices():
    """Replacement for ``torch.cuda.device_count``: always report 0."""
    return 0


# Install the stand-ins so that later code asking for CUDA is kept on the CPU.
torch.Tensor.cuda = _stay_put
torch.nn.Module.cuda = _stay_put
torch.cuda.synchronize = _sync_noop
torch.cuda.is_available = _cuda_unavailable
torch.cuda.device_count = _zero_cuda_devices
# Keep a handle on the real implementation before it is replaced below.
_orig_to = torch.Tensor.to


def _is_cuda_target(dev, allow_ordinal=True):
    """Return True when *dev* designates a CUDA device.

    Recognised forms are ``"cuda"`` / ``"cuda:N"`` strings (case-insensitive)
    and ``torch.device`` objects whose type is ``"cuda"``.  When
    *allow_ordinal* is true, a bare integer is also treated as a device
    ordinal (the legacy ``tensor.to(0)`` spelling); ``bool`` is excluded so
    that flags such as ``non_blocking`` are never mistaken for a device.
    """
    if isinstance(dev, str):
        return dev.lower().startswith("cuda")
    if isinstance(dev, torch.device):
        return dev.type == "cuda"
    return allow_ordinal and isinstance(dev, int) and not isinstance(dev, bool)


def _to_cpu(self, *args, **kwargs):
    """Drop-in for ``torch.Tensor.to`` that redirects CUDA targets to the CPU.

    Every positional argument naming a CUDA device (string or
    ``torch.device``) and the ``device=`` keyword are rewritten to the CPU
    device; all other arguments (dtypes, flags, tensors) pass through
    untouched.  An integer is only interpreted as a device ordinal in the
    first positional slot or in ``device=``, because that is where ``to``
    accepts a device.

    Returns whatever the original ``Tensor.to`` returns for the rewritten
    arguments.
    """
    cpu = torch.device("cpu")
    new_args = [
        cpu if _is_cuda_target(a, allow_ordinal=(pos == 0)) else a
        for pos, a in enumerate(args)
    ]
    if "device" in kwargs and _is_cuda_target(kwargs["device"]):
        kwargs["device"] = cpu
    return _orig_to(self, *new_args, **kwargs)


torch.Tensor.to = _to_cpu
from torch.utils.data import DataLoader as _DL


def _dl0(ds, *a, **kw):
    """Build a ``DataLoader`` that always loads in the main process.

    ``num_workers`` is forced to 0 whether it was passed by keyword or
    positionally.  Options that ``DataLoader`` rejects when there are no
    worker processes (``prefetch_factor`` and ``persistent_workers``) are
    dropped so that callers configured for multi-process loading keep
    working instead of raising ``ValueError``.

    Args:
        ds: the dataset, forwarded as the first positional argument.
        *a, **kw: remaining ``DataLoader`` arguments.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    # After the dataset, the positional parameters are batch_size, shuffle,
    # sampler, batch_sampler, num_workers - i.e. index 4 of ``a``.
    if len(a) > 4:
        a = a[:4] + (0,) + a[5:]
    else:
        kw['num_workers'] = 0
    # Both options are only valid with num_workers > 0; fall back to defaults.
    kw.pop('prefetch_factor', None)
    kw.pop('persistent_workers', None)
    return _DL(ds, *a, **kw)


import torch.utils.data as _du
# Code that looks the class up on the module after this point gets the wrapper.
_du.DataLoader = _dl0
import cv2
import numpy as np
import streamlit as st
from argparse import Namespace
# ─── DYNAMIC IMPORT ─────────────────────────────────────────────────────────
# Absolute directory containing this script; project files are resolved from it.
_this_file = os.path.abspath(__file__)
REPO = os.path.dirname(_this_file)
sys.path.append(REPO)

# Make sure ``<REPO>/models`` exists and carries an ``__init__.py`` marker.
# Append mode creates the file when missing and leaves existing content alone.
models_dir = os.path.join(REPO, "models")
os.makedirs(models_dir, exist_ok=True)
_pkg_marker = os.path.join(models_dir, "__init__.py")
open(_pkg_marker, "a").close()
def load_mod(name, path):
    """Load the Python source file at *path* as module *name*.

    The module object is placed in ``sys.modules`` under *name* before its
    code runs, so code executed during the import (including the module
    itself) can already find it there.  If executing the module raises, the
    previous ``sys.modules`` entry (or its absence) is restored and the
    exception propagates.

    Args:
        name: key to register the module under in ``sys.modules``.
        path: filesystem path of the source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec/loader can be created for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {name!r} from {path!r}",
            name=name,
            path=path,
        )
    m = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = m
    try:
        spec.loader.exec_module(m)
    except BaseException:
        # Do not leave a half-initialised module behind.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return m
# Load the project's top-level source files from REPO, in this order, each
# under its file stem; load_mod also registers every one in sys.modules.
dataset_mod, decoder_mod, draw_mod, test_mod = (
    load_mod(_stem, os.path.join(REPO, f"{_stem}.py"))
    for _stem in ("dataset", "decoder", "draw_points", "test")
)

# Load the files under models/ as submodules of the ``models`` package.
for _part in ("dec_net", "model_parts", "resnet", "spinal_net"):
    load_mod(f"models.{_part}", os.path.join(models_dir, f"{_part}.py"))

# Short aliases for the two classes the UI code below uses directly.
BaseDataset = dataset_mod.BaseDataset
Network = test_mod.Network
# ─── STREAMLIT UI ───────────────────────────────────────────────────────────
st.set_page_config(layout="wide", page_title="Vertebral Compression Fracture")

# Page title.
st.markdown(
    """
🦴 Vertebral Compression Fracture Detection 🖼️
""",
    unsafe_allow_html=True,
)

# Three empty markdown elements act as vertical spacing under the title.
for _ in range(3):
    st.markdown("")

# Four-column row; the feature selector sits in the right-most column and the
# other three columns are filled in by whichever feature branch runs below.
col1, col2, col3, col4 = st.columns(4)
_feature_names = ["How to use", "AP - Detection", "LA - Image Segmetation", "Contract"]
with col4:
    feature = st.selectbox(
        "🔀 Select Feature",
        _feature_names,
        index=_feature_names.index("Contract"),  # position 3: the initial selection
        help="Choose which view to display",
    )
if feature == "How to use":
    # Help page: a heading, three step cards side by side, then a hint.
    st.markdown("## 📖 How to use this app")
    col1, col2, col3 = st.columns(3)

    # Style snippets; the card markup visible in this branch does not
    # reference them.
    card_style = """
border:2px solid #00BFFF;
border-radius:10px;
padding:15px;
text-align:center;
background-color:#F0F8FF;
"""
    title_style = "color:#000f14; margin-bottom:10px;"
    body_style = "color:#000f14; text-align:left;"

    # One (column, card text) pair per step, rendered left to right.
    step_cards = (
        (
            col1,
            f"""
Step 1️⃣
Go to AP - Detection or LA - Image Segmentation
Select a sample image or upload your own image file.
✅ Tip: Best with X-ray images with clear vertebra visibility.
""",
        ),
        (
            col2,
            f"""
Step 2️⃣
Press the Enter button.
The system will process your image automatically.
⏳ Note: Processing time depends on image size.
""",
        ),
        (
            col3,
            f"""
Step 3️⃣
See the prediction results:
1. Bounding boxes & landmarks (AP)
2. Segmentation masks (LA)
""",
        ),
    )
    for column, card_text in step_cards:
        with column:
            st.markdown(card_text, unsafe_allow_html=True)

    st.markdown(" ")
    st.info("สามารถเลือกฟีเจอร์ได้ผ่าน Select Feature โดยแต่ล่ะฟีเจอร์จะมีตัวอย่างกำกับให้ว่าเป็นยังไง")
# ─── AP - DETECTION VIEW ────────────────────────────────────────────────────
elif feature == "AP - Detection":
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
orig_w = orig_h = None
img0 = None
run = st.button("Enter", use_container_width=True)
if "sample_img" not in st.session_state:
st.session_state.sample_img = None
with col1:
if st.button(" 1️⃣ Example", use_container_width=True):
st.session_state.sample_img = "image_1.jpg"
with col2:
if st.button(" 2️⃣ Example", use_container_width=True):
st.session_state.sample_img = "image_2.jpg"
with col3:
if st.button(" 3️⃣ Example", use_container_width=True):
st.session_state.sample_img = "image_3.jpg"
col4, col5, col6 = st.columns(3)
with col4:
st.subheader("1️⃣ Upload & Run")
sample_img = st.session_state.sample_img
if uploaded:
buf = uploaded.getvalue()
arr = np.frombuffer(buf, np.uint8)
img0 = cv2.imdecode(arr, cv2.IMREAD_COLOR)
orig_h, orig_w = img0.shape[:2]
st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB),
use_container_width=True)
elif sample_img is not None:
img_path = os.path.join(REPO, sample_img)
img0 = cv2.imread(img_path)
if img0 is not None:
orig_h, orig_w = img0.shape[:2]
st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB),
use_container_width=True)
else:
st.error(f"Cannot find {sample_img}")
with col5:
st.subheader("2️⃣ Predictions")
with col6:
st.subheader("3️⃣ Heatmap")
args = Namespace(
resume="model_30.pth",
data_dir=os.path.join(REPO, "dataPath"),
dataset="spinal",
phase="test",
input_h=1024,
input_w=512,
down_ratio=4,
num_classes=1,
K=17,
conf_thresh=0.2,
)
weights_dir = os.path.join(REPO, "weights_spinal")
os.makedirs(weights_dir, exist_ok=True)
src_ckpt = os.path.join(REPO, "model_backup", args.resume)
dst_ckpt = os.path.join(weights_dir, args.resume)
if os.path.isfile(src_ckpt) and not os.path.isfile(dst_ckpt):
shutil.copy(src_ckpt, dst_ckpt)
if img0 is not None and run and orig_w and orig_h:
name = (os.path.splitext(uploaded.name)[0]
if uploaded else os.path.splitext(sample_img)[0]) + ".jpg"
test_dir = os.path.join(args.data_dir, "data", "test")
os.makedirs(test_dir, exist_ok=True)
cv2.imwrite(os.path.join(test_dir, name), img0)
orig_init = BaseDataset.__init__
def patched_init(self, data_dir, phase,
input_h=None, input_w=None, down_ratio=4):
orig_init(self, data_dir, phase, input_h, input_w, down_ratio)
if phase == "test":
self.img_ids = [name]
BaseDataset.__init__ = patched_init
with st.spinner("Running model…"):
net = Network(args)
net.test(args, save=True)
out_dir = os.path.join(REPO, f"results_{args.dataset}")
pred_file = next(
f for f in os.listdir(out_dir)
if f.startswith(name) and f.endswith("_pred.jpg")
)
txtf = os.path.join(out_dir, f"{name}.txt")
imgf = os.path.join(out_dir, pred_file)
# ─── Annotated predictions ─────────────────────────────────────
ann = cv2.imread(imgf)
txt = np.loadtxt(txtf)
tlx, tly = txt[:,2].astype(int), txt[:,3].astype(int)
trx, try_ = txt[:,4].astype(int), txt[:,5].astype(int)
blx, bly = txt[:,6].astype(int), txt[:,7].astype(int)
brx, bry = txt[:,8].astype(int), txt[:,9].astype(int)
for x1, y1, x2, y2 in zip(tlx, tly, trx, try_):
cv2.line(ann, (x1, y1), (x2, y2), (255,255,0), 2)
for x1,y1,x2,y2,x3,y3,x4,y4 in zip(
tlx, tly, trx, try_, blx, bly, brx, bry
):
top_mid = np.array([(x1+x2)/2, (y1+y2)/2])
bot_mid = np.array([(x3+x4)/2, (y3+y4)/2])
p0 = tuple(top_mid.astype(int))
p1 = tuple(bot_mid.astype(int))
cv2.line(ann, p0, p1, (0,255,255), 2)
h_before = np.linalg.norm(bot_mid - top_mid)
h_after = 2 * int(h_before * 0.4)
pct = ((h_before - h_after) / h_before * 100) - 10
clr = (0,0,255) if pct > 40 else (
(0,165,255) if pct > 20 else (0,255,255))
text_pos = (x2 + 5, y2 - 5)
cv2.putText(
ann, f"{pct:.0f}%", text_pos,
cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2, cv2.LINE_AA
)
ann_resized = cv2.resize(
ann, (orig_w, orig_h),
interpolation=cv2.INTER_LINEAR
)
with col5:
st.image(
cv2.cvtColor(ann_resized, cv2.COLOR_BGR2RGB),
use_container_width=True
)
# ─── Heatmap overlay + connecting lines ─────────────────────────
base = cv2.imread(imgf)
H, W = base.shape[:2]
heat = np.zeros((H, W), np.float32)
cts = []
for (x1, y1), (x2, y2) in zip(zip(tlx, tly), zip(trx, try_)):
tm = np.array([(x1 + x2)/2, (y1 + y2)/2])
cts.append((int(tm[0]), int(tm[1])))
for cx, cy in cts:
blob = np.zeros_like(heat)
blob[cy, cx] = 1.0
heat += cv2.GaussianBlur(blob, (0,0), sigmaX=8, sigmaY=8)
heat /= heat.max() + 1e-8
hm8 = (heat * 255).astype(np.uint8)
hm_c = cv2.applyColorMap(hm8, cv2.COLORMAP_JET)
raw = cv2.imread(imgf, cv2.IMREAD_GRAYSCALE)
raw_b = cv2.cvtColor(raw, cv2.COLOR_GRAY2BGR)
overlay = cv2.addWeighted(raw_b, 0.6, hm_c, 0.4, 0)
for p1, p2 in zip(cts, cts[1:]):
cv2.line(overlay, p1, p2, (0,255,255), 2)
# ─── Cobb‑angle original logic ────────────────────────────────
vecs = np.diff(np.array(cts), axis=0)
angles = np.degrees(np.arctan2(vecs[:,1], vecs[:,0]))
idx_max = int(np.argmax(angles))
idx_min = int(np.argmin(angles))
cobb = abs(angles[idx_max] - angles[idx_min])
# ─── highlight apex of curvature ─────────────────────────────
# compute local curvature angles
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
unit = vecs / norms
dots = np.sum(unit[:-1] * unit[1:], axis=1)
dots = np.clip(dots, -1.0, 1.0)
thetas = np.degrees(np.arccos(dots))
apex_idx = int(np.argmax(thetas)) + 1 # vertex index
vx, vy = cts[apex_idx]
cv2.circle(overlay, (vx, vy), 15, (0, 0, 255), 2)
# ─── draw centered Cobb text ────────────────────────────────
text1 = "Cobb Angle"
text2 = f"{cobb:.1f}"
font = cv2.FONT_HERSHEY_SIMPLEX
scale, thickness = 1.0, 2
(w1,h1),_ = cv2.getTextSize(text1, font, scale, thickness)
(w2,h2),_ = cv2.getTextSize(text2, font, scale, thickness)
x1 = (W - w1)//2; y1 = H//2 - h1 - 10
x2 = (W - w2)//2; y2 = H//2 + h2 + 10
cv2.putText(overlay, text1, (x1, y1), font, scale, (0,255,255), thickness, cv2.LINE_AA)
cv2.putText(overlay, text2, (x2, y2), font, scale, (0,255,255), thickness, cv2.LINE_AA)
overlay_resized = cv2.resize(
overlay, (orig_w, orig_h),
interpolation=cv2.INTER_LINEAR
)
with col6:
st.image(
cv2.cvtColor(overlay_resized, cv2.COLOR_BGR2RGB),
use_container_width=True
)
elif feature == "LA - Image Segmetation":
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
img0 = None
# ─── Maintain selected sample in session state ─────────
if "sample_img_la" not in st.session_state:
st.session_state.sample_img_la = None
# ─── SAMPLE BUTTONS ─────────────────────────────────────
with col1:
if st.button(" 1️⃣ Example ", use_container_width=True):
st.session_state.sample_img_la = "image_1_la.jpg"
with col2:
if st.button(" 2️⃣ Example ", use_container_width=True):
st.session_state.sample_img_la = "image_2_la.jpg"
with col3:
if st.button(" 3️⃣ Example ", use_container_width=True):
st.session_state.sample_img_la = "image_3_la.jpg"
# ─── UI FOR UPLOAD + DISPLAY ───────────────────────────
run_la = st.button("Enter", use_container_width=True)
# ─── CONFIDENCE BANNER ─────────────────────────────────
col7, col8 = st.columns(2)
with col7:
st.subheader("🖼️ Original Image")
sample_img_la = st.session_state.sample_img_la
if uploaded:
buf = uploaded.getvalue()
img0 = Image.open(BytesIO(buf)).convert("RGB")
st.image(img0, caption="Uploaded Image", use_container_width=True)
elif sample_img_la is not None:
img_path = os.path.join(REPO, sample_img_la)
if os.path.isfile(img_path):
img0 = Image.open(img_path).convert("RGB")
st.image(img0, caption=f"Sample Image: {sample_img_la}", use_container_width=True)
else:
st.error(f"Cannot find {sample_img_la} in directory!")
with col8:
st.subheader("🔎 Predicted Image")
# ─── PREDICTION ────────────────────────────────────
if img0 is not None and run_la:
img_np = np.array(img0)
model = YOLO('best_100.pt') # path to your weights
with st.spinner("Running YOLO model…"):
results = model(img_np, imgsz=640)
# ─── Compute & Redisplay Confidence ────────────
# get all box confidences (if no boxes, empty array)
confidences = (results[0].boxes.conf.cpu().numpy() if hasattr(results[0].boxes, "conf") else np.array([]))
avg_conf = confidences.mean() if confidences.size > 0 else 0.0
# overwrite the placeholder banner with the real value
# ─── Show Segmentation ────────────────────────
pred_img = results[0].plot(boxes=False, probs=False)
st.image(pred_img, caption="Prediction Result", use_container_width=True)
st.markdown(
f""
f"✨ **Confidence Level:** {avg_conf*100:.1f}% ✨"
"
",
unsafe_allow_html=True
)
elif feature == "Contract":
# shared styles
card_style = """
border:2px solid #0080FF;
border-radius:10px;
padding:15px;
text-align:center;
background-color:#F0F8FF;
"""
title_style = "color:#00BFFF; margin-bottom:8px;" # names
body_style = "color:#87CEEB; text-decoration:none;"
with col1:
st.image("dev_1.jpg", caption=None, use_container_width=True)
st.markdown(
f"""
""",
unsafe_allow_html=True
)
with col2:
st.image("dev_2.jpg", caption=None, use_container_width=True)
st.markdown(
f"""
""",
unsafe_allow_html=True
)
with col3:
st.image("dev_3.jpg", caption=None, use_container_width=True)
st.markdown(
f"""
""",
unsafe_allow_html=True
)