"""Streamlit app: vertebral compression fracture detection.

The "Select Feature" box switches between four pages:

* How to use              - usage instructions.
* AP - Detection          - runs the project's landmark network
  (``test.Network``) on one image, shows the annotated prediction, and a
  heat-map of top-edge midpoints with an approximate Cobb angle.
* LA - Image Segmentation - runs the YOLO model ``best_100.pt`` on one image
  and shows the predicted masks with their mean confidence.
* Contact                 - developer cards.

Streamlit re-executes this whole script, in the same process, on every
widget interaction, so module-level side effects below are kept idempotent.
"""
import html
import importlib.util
import os
import shutil
import sys
from argparse import Namespace
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
import torch.utils.data as _du
from PIL import Image
from ultralytics import YOLO


# ─── FORCE CPU ONLY ──────────────────────────────────────────────────────────
def _is_cuda(dev):
    """Return True if *dev* (a str or torch.device) names a CUDA device."""
    if isinstance(dev, str):
        return dev.lower().startswith("cuda")
    return isinstance(dev, torch.device) and dev.type == "cuda"


def _force_cpu_only():
    """Monkey-patch torch so that CUDA requests silently stay on the CPU.

    * ``Tensor.cuda`` / ``Module.cuda`` become no-ops returning ``self``.
    * ``torch.cuda.synchronize`` is a no-op, ``is_available()`` is False and
      ``device_count()`` is 0.
    * ``Tensor.to`` rewrites any CUDA device argument (positional or
      ``device=``) to the CPU.
    * ``torch.utils.data.DataLoader`` is replaced by a wrapper that forces
      ``num_workers=0``.

    Each wrapper carries a ``_cpu_only`` marker so that a Streamlit rerun does
    not wrap the previous wrapper again (the chain would otherwise grow on
    every interaction).
    """
    torch.Tensor.cuda = lambda self, *args, **kwargs: self
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self
    torch.cuda.synchronize = lambda *args, **kwargs: None
    torch.cuda.is_available = lambda: False
    torch.cuda.device_count = lambda: 0

    if not getattr(torch.Tensor.to, "_cpu_only", False):
        orig_to = torch.Tensor.to

        def _to_cpu(self, *args, **kwargs):
            new_args = []
            for a in args:
                if isinstance(a, str) and a.lower().startswith("cuda"):
                    new_args.append("cpu")
                elif isinstance(a, torch.device) and a.type == "cuda":
                    new_args.append(torch.device("cpu"))
                else:
                    new_args.append(a)
            if _is_cuda(kwargs.get("device")):
                kwargs["device"] = torch.device("cpu")
            return orig_to(self, *new_args, **kwargs)

        _to_cpu._cpu_only = True
        torch.Tensor.to = _to_cpu

    # Modules that do `from torch.utils.data import DataLoader` after this
    # point receive the wrapper below.
    if not getattr(_du.DataLoader, "_cpu_only", False):
        orig_dl = _du.DataLoader

        def _dl0(ds, *a, **kw):
            kw["num_workers"] = 0
            return orig_dl(ds, *a, **kw)

        _dl0._cpu_only = True
        _du.DataLoader = _dl0


_force_cpu_only()


# ─── DYNAMIC IMPORT ──────────────────────────────────────────────────────────
REPO = os.path.dirname(os.path.abspath(__file__))
if REPO not in sys.path:  # do not grow sys.path on every Streamlit rerun
    sys.path.append(REPO)

models_dir = os.path.join(REPO, "models")
os.makedirs(models_dir, exist_ok=True)
# Ensure models/ has an __init__.py so it can be imported as a package.
open(os.path.join(models_dir, "__init__.py"), "a").close()


def load_mod(name, path):
    """Execute the source file *path* as module *name* and return it.

    The module is placed in ``sys.modules`` *before* it executes (as in the
    importlib "Importing a source file directly" recipe), so imports of
    *name* made while it runs resolve to this object; the entry is removed
    again if execution fails.

    Raises:
        ImportError: if no loader can be created for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(name, None)
        raise
    return module


dataset_mod = load_mod("dataset", os.path.join(REPO, "dataset.py"))
decoder_mod = load_mod("decoder", os.path.join(REPO, "decoder.py"))
draw_mod = load_mod("draw_points", os.path.join(REPO, "draw_points.py"))
test_mod = load_mod("test", os.path.join(REPO, "test.py"))
for _sub in ("dec_net", "model_parts", "resnet", "spinal_net"):
    load_mod(f"models.{_sub}", os.path.join(models_dir, f"{_sub}.py"))

BaseDataset = dataset_mod.BaseDataset
Network = test_mod.Network


# ─── PAGE NAMES & ASSETS ─────────────────────────────────────────────────────
FEATURE_HOW_TO = "How to use"
FEATURE_AP = "AP - Detection"
FEATURE_LA = "LA - Image Segmentation"
FEATURE_CONTACT = "Contact"
FEATURES = [FEATURE_HOW_TO, FEATURE_AP, FEATURE_LA, FEATURE_CONTACT]

AP_SAMPLES = ("image_1.jpg", "image_2.jpg", "image_3.jpg")
LA_SAMPLES = ("image_1_la.jpg", "image_2_la.jpg", "image_3_la.jpg")

# TODO: fill in each developer's Facebook profile URL; None renders the
# "🔗 Facebook Profile" label without a link.
DEVELOPERS = (
    ("dev_1.jpg", "Thitsanapat S.", None),
    ("dev_2.jpg", "Santipab T.", None),
    ("dev_3.jpg", "Suphanat K.", None),
)


# ─── SHARED UI HELPERS ───────────────────────────────────────────────────────
def _card(title, paragraphs, *, card_style, title_style, body_style):
    """Render *title* and plain-text *paragraphs* (HTML-escaped) as a styled card."""
    body = "".join(
        f'<p style="{body_style}">{html.escape(p)}</p>' for p in paragraphs
    )
    st.markdown(
        f'<div style="{card_style}">'
        f'<h3 style="{title_style}">{html.escape(title)}</h3>'
        f"{body}</div>",
        unsafe_allow_html=True,
    )


def _sample_buttons(columns, state_key, filenames, label_suffix=""):
    """Draw one "Example" button per column; a click stores its file name.

    The chosen file name is kept in ``st.session_state[state_key]`` (None
    until a button is pressed) so it survives reruns.
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = None
    for column, digit, filename in zip(columns, ("1️⃣", "2️⃣", "3️⃣"), filenames):
        with column:
            if st.button(f" {digit} Example{label_suffix}", use_container_width=True):
                st.session_state[state_key] = filename


# ─── HOW TO USE ──────────────────────────────────────────────────────────────
def _render_how_to_use():
    """Render the three instruction cards and the feature-selection hint."""
    st.markdown("## 📖 How to use this app")
    styles = {
        "card_style": (
            "border:2px solid #00BFFF; border-radius:10px; padding:15px; "
            "text-align:center; background-color:#F0F8FF;"
        ),
        "title_style": "color:#000f14; margin-bottom:10px;",
        "body_style": "color:#000f14; text-align:left;",
    }
    steps = (
        ("Step 1️⃣", (
            "Go to AP - Detection or LA - Image Segmentation",
            "Select a sample image or upload your own image file.",
            "✅ Tip: Best with X-ray images with clear vertebra visibility.",
        )),
        ("Step 2️⃣", (
            "Press the Enter button.",
            "The system will process your image automatically.",
            "⏳ Note: Processing time depends on image size.",
        )),
        ("Step 3️⃣", (
            "See the prediction results:",
            "1. Bounding boxes & landmarks (AP)",
            "2. Segmentation masks (LA)",
        )),
    )
    for column, (title, lines) in zip(st.columns(3), steps):
        with column:
            _card(title, lines, **styles)
    st.markdown(" ")
    st.info("สามารถเลือกฟีเจอร์ได้ผ่าน Select Feature โดยแต่ล่ะฟีเจอร์จะมีตัวอย่างกำกับให้ว่าเป็นยังไง")


# ─── AP - DETECTION ──────────────────────────────────────────────────────────
def _stage_checkpoint(ckpt_name):
    """Copy ``model_backup/<ckpt_name>`` into ``weights_spinal/`` once, if present."""
    # NOTE(review): assumes test.Network loads args.resume from weights_spinal/.
    weights_dir = os.path.join(REPO, "weights_spinal")
    os.makedirs(weights_dir, exist_ok=True)
    src_ckpt = os.path.join(REPO, "model_backup", ckpt_name)
    dst_ckpt = os.path.join(weights_dir, ckpt_name)
    if os.path.isfile(src_ckpt) and not os.path.isfile(dst_ckpt):
        shutil.copy(src_ckpt, dst_ckpt)


def _run_spinal_network(args, img0, name):
    """Write *img0* as ``<data_dir>/data/test/<name>`` and run ``Network.test`` on it alone.

    ``BaseDataset.__init__`` is temporarily patched so the test split
    contains only *name*, and is restored afterwards even if the run fails.
    """
    test_dir = os.path.join(args.data_dir, "data", "test")
    os.makedirs(test_dir, exist_ok=True)
    cv2.imwrite(os.path.join(test_dir, name), img0)

    orig_init = BaseDataset.__init__

    def patched_init(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
        orig_init(self, data_dir, phase, input_h, input_w, down_ratio)
        if phase == "test":
            self.img_ids = [name]

    BaseDataset.__init__ = patched_init
    try:
        net = Network(args)
        net.test(args, save=True)
    finally:
        BaseDataset.__init__ = orig_init


def _find_prediction(out_dir, name):
    """Return the path of the first ``<name>…_pred.jpg`` in *out_dir*, or None."""
    if not os.path.isdir(out_dir):
        return None
    for fname in sorted(os.listdir(out_dir)):
        if fname.startswith(name) and fname.endswith("_pred.jpg"):
            return os.path.join(out_dir, fname)
    return None


def _load_corners(txt_path):
    """Return an int array (K, 8) of corner coordinates from columns 2-9 of *txt_path*.

    Column order per row: tl x, tl y, tr x, tr y, bl x, bl y, br x, br y.
    ``ndmin=2`` keeps a single-row file two-dimensional.
    """
    return np.loadtxt(txt_path, ndmin=2)[:, 2:10].astype(int)


def _draw_vertebra_heights(ann, corners):
    """Draw top edges, top-to-bottom midlines and a percentage label (in place)."""
    rows = corners.tolist()
    for x1, y1, x2, y2, *_ in rows:
        cv2.line(ann, (x1, y1), (x2, y2), (255, 255, 0), 2)
    for x1, y1, x2, y2, x3, y3, x4, y4 in rows:
        top_mid = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
        bot_mid = np.array([(x3 + x4) / 2, (y3 + y4) / 2])
        cv2.line(
            ann,
            tuple(int(v) for v in top_mid),
            tuple(int(v) for v in bot_mid),
            (0, 255, 255),
            2,
        )
        h_before = float(np.linalg.norm(bot_mid - top_mid))
        if h_before == 0:
            continue  # degenerate box: no height to compare against
        # NOTE(review): h_after is derived from h_before itself
        # (2 * int(0.4 * h_before) ≈ 0.8 * h_before), so pct always lies in
        # [10, 10 + 200 / h_before) % and the 20 % / 40 % colour thresholds
        # are only reachable for boxes under ~20 px tall. Replace with a
        # measured reference height.
        h_after = 2 * int(h_before * 0.4)
        pct = (h_before - h_after) / h_before * 100 - 10
        if pct > 40:
            clr = (0, 0, 255)
        elif pct > 20:
            clr = (0, 165, 255)
        else:
            clr = (0, 255, 255)
        cv2.putText(
            ann, f"{pct:.0f}%", (x2 + 5, y2 - 5),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2, cv2.LINE_AA,
        )


def _heatmap_overlay(pred_path, points):
    """Blend a Gaussian heat-map of *points* over the grayscale image at *pred_path*.

    Points outside the image are ignored. Consecutive points are joined by
    lines. Returns a BGR image the size of the input.
    """
    gray = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    h, w = gray.shape[:2]
    impulses = np.zeros((h, w), np.float32)
    for cx, cy in points:
        if 0 <= cx < w and 0 <= cy < h:
            impulses[cy, cx] += 1.0
    # GaussianBlur is linear, so one blur of all impulses equals the sum of
    # one blur per point.
    heat = cv2.GaussianBlur(impulses, (0, 0), sigmaX=8, sigmaY=8)
    heat /= heat.max() + 1e-8
    colour = cv2.applyColorMap((heat * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(
        cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR), 0.6, colour, 0.4, 0
    )
    for p1, p2 in zip(points, points[1:]):
        cv2.line(overlay, p1, p2, (0, 255, 255), 2)
    return overlay


def _cobb_and_apex(points):
    """Estimate a Cobb-like angle and the curve apex from ordered 2-D points.

    Returns ``(cobb, apex)``:
      * cobb - max minus min direction angle (degrees) of the segments
        joining consecutive points; None if fewer than 2 points.
      * apex - index of the interior point with the sharpest turn; None if
        fewer than 3 points.

    NOTE(review): this is the spread of the midline's direction, an
    approximation rather than the endplate-based clinical Cobb measurement.
    """
    pts = np.asarray(points, dtype=float).reshape(-1, 2)
    if len(pts) < 2:
        return None, None
    vecs = np.diff(pts, axis=0)
    angles = np.degrees(np.arctan2(vecs[:, 1], vecs[:, 0]))
    cobb = float(angles.max() - angles.min())
    if len(vecs) < 2:
        return cobb, None

    norms = np.linalg.norm(vecs, axis=1)
    nonzero = norms > 0
    unit = np.zeros_like(vecs)
    unit[nonzero] = vecs[nonzero] / norms[nonzero, None]
    dots = np.clip(np.sum(unit[:-1] * unit[1:], axis=1), -1.0, 1.0)
    # A zero-length segment has no direction: treat its joints as straight.
    dots[~(nonzero[:-1] & nonzero[1:])] = 1.0
    thetas = np.degrees(np.arccos(dots))
    return cobb, int(np.argmax(thetas)) + 1  # +1: turn k sits at vertex k+1


def _draw_centred_cobb(overlay, cobb):
    """Write "Cobb Angle" and its value, centred, onto *overlay* (in place)."""
    h, w = overlay.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale, thickness = 1.0, 2
    label, value = "Cobb Angle", f"{cobb:.1f}"
    (w1, h1), _ = cv2.getTextSize(label, font, scale, thickness)
    (w2, h2), _ = cv2.getTextSize(value, font, scale, thickness)
    cv2.putText(overlay, label, ((w - w1) // 2, h // 2 - h1 - 10),
                font, scale, (0, 255, 255), thickness, cv2.LINE_AA)
    cv2.putText(overlay, value, ((w - w2) // 2, h // 2 + h2 + 10),
                font, scale, (0, 255, 255), thickness, cv2.LINE_AA)


def _render_ap_detection(sample_columns):
    """AP page: pick/upload an image, run the landmark network, show results."""
    uploaded = st.file_uploader(
        "Upload an image", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
    )
    run = st.button("Enter", use_container_width=True)
    _sample_buttons(sample_columns, "sample_img", AP_SAMPLES)

    col_in, col_pred, col_heat = st.columns(3)
    sample_img = st.session_state["sample_img"]
    img0 = None
    with col_in:
        st.subheader("1️⃣ Upload & Run")
        if uploaded:
            buf = np.frombuffer(uploaded.getvalue(), np.uint8)
            img0 = cv2.imdecode(buf, cv2.IMREAD_COLOR)
            if img0 is None:
                st.error(f"Cannot decode {uploaded.name}")
        elif sample_img is not None:
            img0 = cv2.imread(os.path.join(REPO, sample_img))
            if img0 is None:
                st.error(f"Cannot find {sample_img}")
        if img0 is not None:
            st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB), use_container_width=True)
    with col_pred:
        st.subheader("2️⃣ Predictions")
    with col_heat:
        st.subheader("3️⃣ Heatmap")

    args = Namespace(
        resume="model_30.pth",
        data_dir=os.path.join(REPO, "dataPath"),
        dataset="spinal",
        phase="test",
        input_h=1024,
        input_w=512,
        down_ratio=4,
        num_classes=1,
        K=17,
        conf_thresh=0.2,
    )
    _stage_checkpoint(args.resume)

    if img0 is None or not run:
        return
    orig_h, orig_w = img0.shape[:2]

    # basename(): the upload's file name comes from the client.
    source = os.path.basename(uploaded.name) if uploaded else sample_img
    name = os.path.splitext(source)[0] + ".jpg"

    with st.spinner("Running model…"):
        _run_spinal_network(args, img0, name)

    out_dir = os.path.join(REPO, f"results_{args.dataset}")
    pred_path = _find_prediction(out_dir, name)
    txt_path = os.path.join(out_dir, f"{name}.txt")
    if pred_path is None or not os.path.isfile(txt_path):
        st.error("The model produced no prediction for this image.")
        return
    corners = _load_corners(txt_path)
    if corners.size == 0:
        st.error("The model produced no prediction for this image.")
        return

    # ─── Annotated predictions ───────────────────────────────────────────
    ann = cv2.imread(pred_path)
    _draw_vertebra_heights(ann, corners)
    ann_resized = cv2.resize(ann, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    with col_pred:
        st.image(cv2.cvtColor(ann_resized, cv2.COLOR_BGR2RGB), use_container_width=True)

    # ─── Heat-map of top-edge midpoints + Cobb angle + apex ──────────────
    top_mids = [
        (int((x1 + x2) / 2), int((y1 + y2) / 2))
        for x1, y1, x2, y2 in corners[:, :4].tolist()
    ]
    overlay = _heatmap_overlay(pred_path, top_mids)
    cobb, apex_idx = _cobb_and_apex(top_mids)
    if apex_idx is not None:
        cv2.circle(overlay, top_mids[apex_idx], 15, (0, 0, 255), 2)
    if cobb is not None:
        _draw_centred_cobb(overlay, cobb)
    overlay_resized = cv2.resize(
        overlay, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR
    )
    with col_heat:
        st.image(cv2.cvtColor(overlay_resized, cv2.COLOR_BGR2RGB), use_container_width=True)


# ─── LA - IMAGE SEGMENTATION ─────────────────────────────────────────────────
def _render_la_segmentation(sample_columns):
    """LA page: pick/upload an image, run YOLO segmentation, show masks and confidence."""
    uploaded = st.file_uploader(
        "Upload an image", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
    )
    _sample_buttons(sample_columns, "sample_img_la", LA_SAMPLES, label_suffix=" ")
    run_la = st.button("Enter", use_container_width=True)

    col_in, col_out = st.columns(2)
    sample_img_la = st.session_state["sample_img_la"]
    img0 = None
    with col_in:
        st.subheader("🖼️ Original Image")
        if uploaded:
            try:
                with Image.open(BytesIO(uploaded.getvalue())) as im:
                    img0 = im.convert("RGB")
            except OSError:  # includes PIL.UnidentifiedImageError
                st.error(f"Cannot decode {uploaded.name}")
            else:
                st.image(img0, caption="Uploaded Image", use_container_width=True)
        elif sample_img_la is not None:
            img_path = os.path.join(REPO, sample_img_la)
            if os.path.isfile(img_path):
                with Image.open(img_path) as im:
                    img0 = im.convert("RGB")
                st.image(img0, caption=f"Sample Image: {sample_img_la}",
                         use_container_width=True)
            else:
                st.error(f"Cannot find {sample_img_la} in directory!")

    with col_out:
        st.subheader("🔎 Predicted Image")
        if img0 is None or not run_la:
            return
        model = YOLO(os.path.join(REPO, "best_100.pt"))
        with st.spinner("Running YOLO model…"):
            # Pass the PIL image: ultralytics reads raw numpy arrays as BGR,
            # but converts PIL (RGB) input itself.
            results = model(img0, imgsz=640)
        result = results[0]

        boxes = result.boxes
        confidences = (
            boxes.conf.cpu().numpy() if hasattr(boxes, "conf") else np.array([])
        )
        avg_conf = confidences.mean() if confidences.size > 0 else 0.0

        # Results.plot() returns a BGR array.
        pred_img = result.plot(boxes=False, probs=False)
        st.image(pred_img, caption="Prediction Result", channels="BGR",
                 use_container_width=True)
        st.markdown(f"✨ **Confidence Level:** {avg_conf * 100:.1f}% ✨")


# ─── CONTACT ─────────────────────────────────────────────────────────────────
def _render_contact(columns):
    """Contact page: one photo + name card per developer."""
    card_style = (
        "border:2px solid #0080FF; border-radius:10px; padding:15px; "
        "text-align:center; background-color:#F0F8FF;"
    )
    title_style = "color:#00BFFF; margin-bottom:8px;"  # names
    body_style = "color:#87CEEB; text-decoration:none;"
    label = "🔗 Facebook Profile"
    for column, (photo, dev_name, url) in zip(columns, DEVELOPERS):
        with column:
            st.image(os.path.join(REPO, photo), caption=None, use_container_width=True)
            if url:
                link = (f'<a href="{html.escape(url, quote=True)}" target="_blank" '
                        f'style="{body_style}">{label}</a>')
            else:
                link = f'<span style="{body_style}">{label}</span>'
            st.markdown(
                f'<div style="{card_style}">'
                f'<h4 style="{title_style}">{html.escape(dev_name)}</h4>{link}</div>',
                unsafe_allow_html=True,
            )


# ─── STREAMLIT UI ────────────────────────────────────────────────────────────
st.set_page_config(layout="wide", page_title="Vertebral Compression Fracture")
st.markdown(
    '<h1 style="text-align:center;">'
    "🦴 Vertebral Compression Fracture Detection 🖼️</h1>",
    unsafe_allow_html=True,
)
for _ in range(3):
    st.markdown("")

col1, col2, col3, col4 = st.columns(4)
with col4:
    feature = st.selectbox(
        "🔀 Select Feature",
        FEATURES,
        index=FEATURES.index(FEATURE_CONTACT),  # default page: "Contact"
        help="Choose which view to display",
    )

# The three left-hand columns of the top row host the example buttons
# (AP / LA) and the developer cards (Contact).
top_columns = (col1, col2, col3)
if feature == FEATURE_HOW_TO:
    _render_how_to_use()
elif feature == FEATURE_AP:
    _render_ap_detection(top_columns)
elif feature == FEATURE_LA:
    _render_la_segmentation(top_columns)
elif feature == FEATURE_CONTACT:
    _render_contact(top_columns)