PCOS Detection with Explainable AI

A deep learning model for Polycystic Ovary Syndrome (PCOS) detection from ultrasound images with Grad-CAM visualization for clinical interpretability.

Model Overview

  • Architecture: Dual-path CNN with multi-head attention
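  • Input: 224×224 RGB ultrasound images (pixel values scaled to [0, 1])
  • Output: Binary classification (PCOS-positive / Healthy), reported by the example code under the class names "infected" / "noninfected"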
  • Accuracy: ~95%+ on test set
  • XAI: Grad-CAM heatmaps for interpretability

🚀 Quick Start

pip install tensorflow opencv-python matplotlib numpy requests huggingface-hub

Complete Working Example

# ============================================================
# πŸ” PCOS Prediction + Grad-CAM (HF VERSION)
# ============================================================

import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense,
    Lambda, Reshape, Concatenate,
    MultiHeadAttention, GlobalAveragePooling1D
)
import requests
from huggingface_hub import hf_hub_download

# ============================================================
# Config
# ============================================================
IMG_SIZE = (224, 224)
HF_MODEL_REPO = "Dehsahk-AI/Pcos-Detect"
MODEL_FILENAME = "best_pcos_model.h5"
IMAGE_URL = "https://example.com/ultrasound.jpg"  # Your image URL
CLASS_NAMES = ["infected", "noninfected"]

# ============================================================
# Download model from HF
# ============================================================
MODEL_PATH = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME)
print(f" Model downloaded to: {MODEL_PATH}")

# ============================================================
# Custom Lambda Functions
# ============================================================
def split_image(image):
    upper = image[:, :IMG_SIZE[0]//2, :, :]
    lower = image[:, IMG_SIZE[0]//2:, :, :]
    return upper, lower

def flip_lower(lower_half):
    return tf.image.flip_left_right(lower_half)

# ============================================================
# Rebuild Model Architecture
# ============================================================
input_layer = Input(shape=(224,224,3))

upper_half, lower_half = Lambda(split_image)(input_layer)
lower_half = Lambda(flip_lower)(lower_half)

# Upper CNN
u = Conv2D(32, 3, activation="relu", padding="same")(upper_half)
u = MaxPooling2D(2)(u)
u = Conv2D(64, 3, activation="relu", padding="same")(u)
u = MaxPooling2D(2)(u)
u = Conv2D(128, 3, activation="relu", padding="same", name="upper_last_conv")(u)
u = MaxPooling2D(2)(u)
u = Flatten()(u)

# Lower CNN
l = Conv2D(32, 3, activation="relu", padding="same")(lower_half)
l = MaxPooling2D(2)(l)
l = Conv2D(64, 3, activation="relu", padding="same")(l)
l = MaxPooling2D(2)(l)
l = Conv2D(128, 3, activation="relu", padding="same", name="lower_last_conv")(l)
l = MaxPooling2D(2)(l)
l = Flatten()(l)

u_dense = Dense(512, activation="relu")(u)
l_dense = Dense(512, activation="relu")(l)

u_r = Reshape((1,512))(u_dense)
l_r = Reshape((1,512))(l_dense)

concat = Concatenate(axis=1)([u_r, l_r])

att = MultiHeadAttention(num_heads=4, key_dim=64)(concat, concat)
att = GlobalAveragePooling1D()(att)

fc = Dense(256, activation="relu")(att)
fc = Dense(128, activation="relu")(fc)

# Logits for Grad-CAM
logits = Dense(2, name="logits")(fc)
output = tf.keras.layers.Activation('softmax', name='softmax')(logits)

model = Model(input_layer, output)
model.load_weights(MODEL_PATH)
print(" Weights loaded successfully")

# ============================================================
# Load & Preprocess Image
# ============================================================
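response = requests.get(IMAGE_URL, timeout=30)
response.raise_for_status()  # stop early on HTTP errors (e.g. 404)
img_array_raw = np.asarray(bytearray(response.content), dtype=np.uint8)
img = cv2.imdecode(img_array_raw, cv2.IMREAD_COLOR)
if img is None:  # imdecode returns None when the bytes are not a valid image
    raise ValueError(f"Could not decode an image from {IMAGE_URL}")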
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0
img_array = np.expand_dims(img, axis=0)

# ============================================================
# Prediction
# ============================================================
pred = model.predict(img_array, verbose=0)[0]
pred_class = np.argmax(pred)
confidence = pred[pred_class]

print(f"\n Prediction: {CLASS_NAMES[pred_class]}")
print(f" Confidence: {confidence:.2%}")

# ============================================================
# Grad-CAM
# ============================================================
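def gradcam(img_array, model, layer_name, pred_index):
    """Grad-CAM heatmap in [0, 1] for one conv layer, computed from the pre-softmax logits."""
    logits_layer = model.get_layer('logits')
    # Sub-model that returns both the target conv feature map and the logits
    grad_model = Model(
        model.input,
        [model.get_layer(layer_name).output, logits_layer.output]
    )

    with tf.GradientTape() as tape:
        conv_out, logits = grad_model(img_array)
        loss = logits[:, pred_index]  # raw score of the class being explained

    # Gradient of the class score w.r.t. the feature map, averaged over
    # batch and spatial axes: one importance weight per channel
    grads = tape.gradient(loss, conv_out)
    pooled = tf.reduce_mean(grads, axis=(0,1,2))
    conv_out = conv_out[0]

    # Weighted sum of channels, then ReLU to keep only positive evidence
    heatmap = conv_out @ pooled[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)

    # Normalize to [0, 1] (skipped when the map is all zeros)
    if tf.reduce_max(heatmap) > 0:
        heatmap /= tf.reduce_max(heatmap)

    return heatmap.numpy()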

upper = gradcam(img_array, model, "upper_last_conv", pred_class)
lower = gradcam(img_array, model, "lower_last_conv", pred_class)

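# Each half-map has the last conv layer's 28×56 grid; upsample it to 112×224.
# Note that cv2.resize takes dsize as (width, height).
h = IMG_SIZE[0] // 2
upper = cv2.resize(upper, (IMG_SIZE[1], h))
lower = cv2.resize(lower, (IMG_SIZE[1], h))
lower = cv2.flip(lower, 1)  # undo the left-right flip applied to the lower half inside the model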

heatmap = np.vstack([upper, lower])

heatmap_color = cv2.applyColorMap(np.uint8(255*heatmap), cv2.COLORMAP_JET)
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) / 255.0

overlay = 0.5 * heatmap_color + 0.5 * img

# ============================================================
# Visualization
# ============================================================
plt.figure(figsize=(15,5))

plt.subplot(1,3,1)
plt.imshow(img)
plt.title("Original")
plt.axis("off")

plt.subplot(1,3,2)
plt.imshow(heatmap, cmap="jet")
plt.title("Grad-CAM")
plt.axis("off")

plt.subplot(1,3,3)
plt.imshow(overlay)
plt.title(f"{CLASS_NAMES[pred_class]} ({confidence:.2%})")
plt.axis("off")

plt.tight_layout()
plt.show()

Load from Local File

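# Replace the URL-loading lines in "Load & Preprocess Image" with:
img = cv2.imread('path/to/ultrasound.jpg')  # BGR uint8 array, or None if the file can't be read
if img is None:
    raise FileNotFoundError("Could not read 'path/to/ultrasound.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)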
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0
img_array = np.expand_dims(img, axis=0)

Understanding Grad-CAM Output

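  • Red/hot regions: highest contribution to the predicted class score (ideally follicles or cysts)
  • Blue/cool regions: little or no positive contribution to the decision
  • Dual visualization: heatmaps are computed separately for the upper and lower image halves, then stitched back into one 224×224 map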
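The numbers behind the colors can be reproduced with NumPy alone. This is a minimal sketch of the same steps gradcam() performs (channel weights from spatially averaged gradients, weighted channel sum, ReLU, normalization), plus the half-map stitching. gradcam_from_arrays and stitch_halves are hypothetical helper names, the feature maps and gradients are random stand-ins rather than real model outputs, and integer nearest-neighbour upsampling replaces the bilinear cv2.resize used in the example.

```python
import numpy as np


def gradcam_from_arrays(conv_out: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Grad-CAM map in [0, 1] from one feature map and its gradient, both shaped (H, W, C)."""
    # One weight per channel: the gradient averaged over the spatial grid
    weights = grads.mean(axis=(0, 1))            # shape (C,)
    # Weighted sum over channels -> (H, W); ReLU keeps only positive evidence
    cam = np.maximum(conv_out @ weights, 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


def stitch_halves(upper: np.ndarray, lower: np.ndarray, scale: int) -> np.ndarray:
    """Upsample both half-maps by an integer factor, un-mirror the lower one, stack vertically."""
    up = np.repeat(np.repeat(upper, scale, axis=0), scale, axis=1)
    lo = np.repeat(np.repeat(lower, scale, axis=0), scale, axis=1)
    lo = np.fliplr(lo)                           # undo the left-right flip of the lower path
    return np.vstack([up, lo])


# Toy inputs shaped like upper_last_conv / lower_last_conv outputs: 28x56 grid, 128 channels
rng = np.random.default_rng(0)
conv_u, grad_u = rng.random((28, 56, 128)), rng.standard_normal((28, 56, 128))
conv_l, grad_l = rng.random((28, 56, 128)), rng.standard_normal((28, 56, 128))

heatmap = stitch_halves(gradcam_from_arrays(conv_u, grad_u),
                        gradcam_from_arrays(conv_l, grad_l), scale=4)
print(heatmap.shape)  # (224, 224)
```

Because of the ReLU, "blue" pixels are those where the weighted channel sum is zero or negative, so they carry no positive evidence for the predicted class rather than evidence against it.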

Model Architecture

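Input (224×224×3)
├── Split horizontally into upper and lower halves (112×224×3 each)
├── Upper Path: Conv32 → Conv64 → Conv128 → Flatten → Dense512
├── Lower Path: Flip left-right → Conv32 → Conv64 → Conv128 → Flatten → Dense512
├── Multi-Head Attention over the two path outputs (4 heads, key_dim=64) → Global Average Pooling
└── Classification: Dense256 → Dense128 → Dense2 (logits) → Softmax

(Each Conv is a 3×3 "same"-padded ReLU convolution followed by 2×2 max-pooling.)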

Key Features:

  • Dual-path CNN for separate ovarian region analysis
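  • Lower half mirrored left-to-right (tf.image.flip_left_right) before its CNN, for symmetry normalization
  • Multi-head attention for feature fusion
  • Logits-based Grad-CAM: gradients are taken from the pre-softmax "logits" layer, avoiding the vanishing gradients of a saturated softmax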
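The last point can be checked numerically. For p = softmax(z), the derivative of the predicted probability p_k with respect to logit z_j is p_k(δ_kj − p_j), so every term shrinks toward zero as p_k approaches 1, and the chain-rule gradients Grad-CAM relies on shrink with it; the logit z_k itself always has derivative 1 with respect to z_k. The NumPy sketch below compares a mild and a saturated pair of made-up logits (illustrative values, not outputs of this model).

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())          # subtract the max for numerical stability
    return e / e.sum()


def softmax_jacobian(z: np.ndarray) -> np.ndarray:
    """J[k, j] = d p_k / d z_j = p_k * (delta_kj - p_j)."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)


max_grads = []
for logits in (np.array([1.0, 0.0]), np.array([12.0, -12.0])):   # mild vs. saturated
    k = int(np.argmax(logits))
    grad_prob = softmax_jacobian(logits)[k]      # gradient of p_k w.r.t. every logit
    max_grads.append(np.abs(grad_prob).max())
    print(f"logits={logits}, p_k={softmax(logits)[k]:.6f}, "
          f"max |d p_k/d z_j|={max_grads[-1]:.2e}, d z_k/d z_k=1")
```

In the saturated case the probability gradient is around 1e-11, so a softmax-based heatmap would be dominated by numerical noise, while the logit gradient is unaffected.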

Dataset

  • Total: 11,784 ultrasound images
  • PCOS-positive: 6,784 images (57.6%)
  • Healthy: 5,000 images (42.4%)
  • Source: 3 clinics (2018-2022), expert-annotated
  • Dataset: PCOS XAI Ultrasound
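The class shares follow directly from the counts. The same arithmetic gives the accuracy of a trivial model that always answers "PCOS-positive", a useful reference point for the ~95% figure in the overview. This sketch assumes the test split keeps the overall class ratio, which the card does not state.

```python
counts = {"PCOS-positive": 6_784, "Healthy": 5_000}
total = sum(counts.values())
assert total == 11_784  # matches the reported total

for label, n in counts.items():
    print(f"{label}: {n} images, {n / total:.1%}")

# Accuracy of always predicting the majority class
# (assumes the test split has the same class ratio as the full dataset)
baseline = max(counts.values()) / total
print(f"Majority-class baseline accuracy: {baseline:.1%}")
```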

Important Notes

Clinical Use:

  • Research purposes only - NOT FDA approved
  • Not a diagnostic tool - requires professional validation
  • Must be validated on local datasets before clinical deployment

Technical:

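  • Fixed 224×224 input size required (resize before prediction)
  • 3-channel RGB input scaled to [0, 1]; grayscale files are expanded to three channels when read with cv2.IMREAD_COLOR, and OpenCV's BGR order must be converted to RGB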
  • Model performance may vary across different ultrasound machines
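These input requirements can be enforced before calling model.predict. The helper below is a hypothetical NumPy-only sketch, assuming the image has already been resized to 224×224 and converted from OpenCV's BGR order to RGB; it repeats a single grayscale channel three times (as cv2.IMREAD_COLOR does for grayscale files) and applies the same /255 scaling as the example.

```python
import numpy as np

IMG_SIZE = (224, 224)


def to_model_input(img: np.ndarray) -> np.ndarray:
    """Turn one already-resized RGB or grayscale image into a (1, 224, 224, 3) float32 batch.

    Accepts (H, W), (H, W, 1) or (H, W, 3) arrays. Resizing and BGR->RGB conversion
    are NOT done here; do them with OpenCV as in the example.
    """
    if img.ndim == 2:                          # grayscale -> add a channel axis
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:                     # replicate the single channel three times
        img = np.repeat(img, 3, axis=-1)
    if img.shape != (*IMG_SIZE, 3):
        raise ValueError(f"expected {(*IMG_SIZE, 3)} after channel fix, got {img.shape}")
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0   # same scaling as the example pipeline
    elif img.min() < 0.0 or img.max() > 1.0:
        raise ValueError("float input must already be scaled to [0, 1]")
    return np.expand_dims(img.astype(np.float32), axis=0)


gray = np.full((224, 224), 128, dtype=np.uint8)   # synthetic stand-in for an ultrasound frame
batch = to_model_input(gray)
print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32
```

Failing loudly on a wrong shape or an unscaled float array is deliberate: the model would otherwise still return a confident-looking prediction for malformed input.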

Citation

@misc{pcos_xai_2024,
  title={PCOS Detection with Explainable AI},
  author={Dehsahk-AI},
  year={2025},
  url={https://huggingface.co/Dehsahk-AI/Pcos-Detect}
}

License

MIT License - See LICENSE file for details.


Model Version: 1.0 | Last Updated: December 2025
