EfficientNet-B3 for Periapical Index (PAI) Scoring
This repository contains the official model weights for an EfficientNet-B3 model fine-tuned for Periapical Index (PAI) scoring of dental radiographs. PAI scoring is a standardized method for assessing apical periodontitis. It rates the periapical status of a root on a five-point scale, from 1 (normal periapical structures) to 5 (severe periodontitis with exacerbating features).
This model was developed to make endodontic diagnosis more consistent and efficient by automating PAI classification. For a complete overview of the training, data handling, and XAI analysis methodology, see the source code repository:
- Source Code: https://github.com/geraldOslo/PAI-meets-AI
Model Details
Attribute | Value |
---|---|
Architecture | efficientnet_b3 |
Library | timm |
Input Size | 300x300 pixels |
Classes | 5 (PAI Scores 1-5) |
Normalization Mean | [0.3785, 0.3785, 0.3785] |
Normalization Std | [0.1675, 0.1675, 0.1675] |
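`transforms.ToTensor()` scales 8-bit pixels to [0, 1], and `transforms.Normalize` then applies (x - mean) / std to each channel. All three channels share the same statistics, which is consistent with grayscale radiographs replicated to RGB (an assumption about the training data). The arithmetic therefore reduces to a single scalar transform. The sketch below shows the forward and inverse mappings, which can be useful when displaying input tensors or overlaying heatmaps. `normalize_pixel` and `denormalize` are hypothetical helper names, not part of the repository.

```python
MEAN = 0.3785  # per-channel mean from the table (identical for R, G, B)
STD = 0.1675   # per-channel std from the table (identical for R, G, B)


def normalize_pixel(value_8bit: int) -> float:
    """Map an 8-bit intensity (0-255) to the value the model receives."""
    x = value_8bit / 255.0      # ToTensor: uint8 -> [0, 1]
    return (x - MEAN) / STD     # Normalize: (x - mean) / std


def denormalize(value: float) -> int:
    """Invert normalize_pixel and round back to the nearest 8-bit level."""
    x = value * STD + MEAN
    return round(min(max(x, 0.0), 1.0) * 255)


for v in (0, 97, 255):
    print(v, round(normalize_pixel(v), 3))
```

A black pixel maps to a strongly negative value and a white pixel to a large positive one. Pixels near the dataset mean intensity map close to zero.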
How to Use the Model
The model checkpoint contains all necessary configuration (`model_name`, `num_classes`, etc.) for easy loading.
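The loader reads this configuration with fallbacks. A checkpoint without a `model_config` entry therefore still produces an `efficientnet_b3` with five outputs. The hypothetical helper below isolates that logic in plain Python. The key names `model_config`, `model_name`, and `num_classes` are the ones used by the loading code in this card. The helper itself is illustrative and differs slightly from that code: it also tolerates a `model_config` value of `None`.

```python
from typing import Any, Mapping, Tuple

DEFAULT_MODEL_NAME = "efficientnet_b3"
DEFAULT_NUM_CLASSES = 5


def resolve_model_config(checkpoint: Mapping[str, Any]) -> Tuple[str, int]:
    """Return (model_name, num_classes), falling back to the documented defaults."""
    config = checkpoint.get("model_config") or {}   # missing or None -> empty dict
    model_name = config.get("model_name", DEFAULT_MODEL_NAME)
    num_classes = int(config.get("num_classes", DEFAULT_NUM_CLASSES))
    return model_name, num_classes


print(resolve_model_config({"model_config": {"model_name": "efficientnet_b3", "num_classes": 5}}))
```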
1. Installation
This model requires `timm`, `torch`, `torchvision`, and `huggingface_hub`, plus `Pillow` for image loading.
```bash
pip install timm torch torchvision "huggingface_hub[cli]" pillow
```
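You can confirm that the packages are installed in the active environment before running the loader with a small standard-library check like the one below. `installed_versions` is a hypothetical helper. The distribution names are assumed to match those in the `pip install` line.

```python
from importlib.metadata import PackageNotFoundError, version

REQUIRED = ("timm", "torch", "torchvision", "huggingface_hub", "Pillow")


def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:16} {ver or 'NOT INSTALLED'}")
```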
2. Loading the Model and Preprocessing
The following snippet shows how to download the model from the Hugging Face Hub, load it with the correct configuration, and prepare an image for inference.
```python
import sys

import torch
import timm
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download


def load_pai_model_from_hub(
    repo_id="geraldoslo/PAI-meets-AI",
    filename="efficientnet_b3_NVIDIA_H100_PCIe_MIG_1g.20gb_20250531_121520_best.pth",
):
    """
    Downloads and loads the PAI classification model from the Hugging Face Hub.
    The checkpoint contains the model configuration, so it is loaded automatically.
    """
    # Download the checkpoint file (cached locally after the first call)
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Use the GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the checkpoint. Since PyTorch 2.6, torch.load defaults to
    # weights_only=True; if loading fails because the checkpoint stores
    # non-tensor objects, pass weights_only=False (only for trusted files).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract the model configuration, falling back to the documented defaults
    model_config = checkpoint.get("model_config", {})
    model_name = model_config.get("model_name", "efficientnet_b3")
    num_classes = model_config.get("num_classes", 5)

    # Create the architecture with timm; the weights come from the checkpoint
    model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)

    # Load the fine-tuned weights
    model.load_state_dict(checkpoint["model_state_dict"])

    # Move the model to the device and switch to evaluation mode
    model.to(device)
    model.eval()

    print(f"Model '{model_name}' loaded successfully on '{device}'.")
    return model, device


def prepare_image(image_path):
    """
    Loads and preprocesses an image for the PAI model.
    """
    # Normalization parameters must match those used during training
    transform = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3785, 0.3785, 0.3785], std=[0.1675, 0.1675, 0.1675]),
    ])
    # convert("RGB") replicates a single-channel (grayscale) image into three channels
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension: (1, 3, 300, 300)


# --- Example Usage ---
if __name__ == "__main__":
    # 1. Load the model
    pai_model, device = load_pai_model_from_hub()

    # 2. Prepare your input image
    # Replace "path/to/your/radiograph.png" with your actual image file
    try:
        input_tensor = prepare_image("path/to/your/radiograph.png")
    except FileNotFoundError:
        print("Error: Please replace 'path/to/your/radiograph.png' with a valid image file path.")
        sys.exit(1)

    # 3. Run inference without tracking gradients
    with torch.no_grad():
        output = pai_model(input_tensor.to(device))
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class_idx].item()

    # PAI scores are 1-based, so add 1 to the 0-indexed output
    predicted_pai_score = predicted_class_idx + 1
    print(f"\nPredicted PAI score: {predicted_pai_score}")
    print(f"Confidence: {confidence:.2%}")
```
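The post-processing at the end of the script (softmax, argmax, then the +1 offset) can also be reproduced without PyTorch. This may help when logits come from an exported model or a log file. The sketch assumes that class index 0 corresponds to PAI 1, as the +1 offset above implies. It also attaches paraphrased standard PAI descriptors as labels. `logits_to_pai` and `PAI_DESCRIPTIONS` are hypothetical names, not part of the repository.

```python
import math
from typing import Sequence, Tuple

# Paraphrased standard PAI descriptors (assumption: classes are ordered PAI 1..5)
PAI_DESCRIPTIONS = {
    1: "Normal periapical structures",
    2: "Small changes in bone structure",
    3: "Changes in bone structure with some mineral loss",
    4: "Periodontitis with well-defined radiolucent area",
    5: "Severe periodontitis with exacerbating features",
}


def logits_to_pai(logits: Sequence[float]) -> Tuple[int, float, str]:
    """Convert one row of five logits to (PAI score, confidence, description)."""
    if len(logits) != 5:
        raise ValueError("expected 5 logits, one per PAI score")
    m = max(logits)                                   # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]                 # softmax
    idx = max(range(len(probs)), key=probs.__getitem__)  # argmax (first maximum on ties)
    score = idx + 1                                   # class index 0-4 -> PAI 1-5
    return score, probs[idx], PAI_DESCRIPTIONS[score]


score, confidence, description = logits_to_pai([0.1, 2.0, 0.3, -1.0, 0.5])
print(score, f"{confidence:.2%}", description)
```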
3. Explainable AI (XAI)
The source code repository includes Jupyter notebooks and utilities for generating class activation maps (CAMs), such as Grad-CAM, ScoreCAM, and LayerCAM. These heatmaps highlight the regions of the radiograph that most influenced the model's prediction, which aids model interpretability and validation.
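The core Grad-CAM step works as follows. The gradients of the target class score are averaged over each feature map's spatial positions to obtain channel weights. The feature maps are then summed using those weights, and a ReLU keeps only positive evidence. The NumPy sketch below assumes that the activations and gradients of one convolutional layer have already been captured for a single image, for example with PyTorch forward and backward hooks. It is not the repository's implementation. Upsampling the map to 300 × 300 and overlaying it on the radiograph are omitted. `grad_cam` is a hypothetical name.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine a layer's activations and gradients into a Grad-CAM heatmap.

    activations, gradients: arrays of shape (K, H, W) for one image, where K is
    the number of channels in the chosen convolutional layer.
    Returns an (H, W) map scaled to [0, 1].
    """
    if activations.shape != gradients.shape or activations.ndim != 3:
        raise ValueError("expected matching (K, H, W) arrays")
    weights = gradients.mean(axis=(1, 2))                  # alpha_k: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=(0, 0))  # sum_k alpha_k * A_k -> (H, W)
    cam = np.maximum(cam, 0.0)                             # ReLU: keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam                 # scale to [0, 1] for display
```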
Data and Intended Use
Data Usage Statement
The dataset of radiograph clips and PAI scores used for training and testing this model is a private clinical dataset and is not shared due to patient privacy and data protection regulations. Users can reproduce the training and inference pipelines on their own data but cannot access the original dataset.
Intended Use
This model is intended for research and educational purposes only. It is not a certified medical device and should not be used for clinical diagnosis or decision-making without further validation and regulatory approval. The model's performance may vary on datasets with different characteristics (e.g., different imaging equipment, patient populations, or preprocessing).
Citation
If you use this model in your research, please cite the GitHub repository:
```bibtex
@software{Torgersen_PAI-meets-AI_2025,
  author    = {Torgersen, Gerald},
  title     = {{PAI-meets-AI: A Deep Learning Framework for Periapical Index Classification}},
  month     = {6},
  year      = {2025},
  publisher = {GitHub},
  version   = {1.0.0},
  url       = {https://github.com/geraldOslo/PAI-meets-AI}
}
```
Base model: google/efficientnet-b3