EfficientNet-B3 for Periapical Index (PAI) Scoring

This repository contains the official weights of an EfficientNet-B3 model fine-tuned for Periapical Index (PAI) scoring of dental radiographs. PAI scoring is a standardized method for assessing apical periodontitis. It rates periapical status on an ordinal 1-5 scale, where higher scores indicate more severe pathology.

This model was developed to make endodontic diagnosis more consistent and efficient by automating PAI classification. For a complete overview of the training, data handling, and XAI analysis methodology, see the source code repository at https://github.com/geraldOslo/PAI-meets-AI.

Model Details

Attribute              Value
---------------------  ------------------------
Architecture           efficientnet_b3
Library                timm
Input size             300 x 300 pixels
Classes                5 (PAI scores 1-5)
Normalization mean     [0.3785, 0.3785, 0.3785]
Normalization std      [0.1675, 0.1675, 0.1675]
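The normalization values apply per channel after torchvision's `ToTensor` has scaled 8-bit pixel values to [0, 1]. Each value x becomes (x - mean) / std. The three channel statistics are identical, so a grayscale radiograph converted to RGB receives the same transform in every channel. The sketch below is plain Python and is for illustration only; the helper names are hypothetical and are not part of the repository. It shows the forward transform and its inverse. The inverse is useful for turning a normalized input back into a viewable image, for example before overlaying a heatmap.

```python
# Hypothetical helpers illustrating the normalization listed in Model Details.
# Assumes 8-bit pixel values (0-255), which torchvision's ToTensor scales to [0, 1].
PAI_MEAN = 0.3785  # identical for R, G and B
PAI_STD = 0.1675   # identical for R, G and B


def normalize_pixel(value_8bit: int) -> float:
    """Map an 8-bit pixel value to the normalized value the model sees."""
    scaled = value_8bit / 255.0           # the scaling ToTensor performs
    return (scaled - PAI_MEAN) / PAI_STD  # the shift/scale Normalize performs


def denormalize_pixel(normalized: float) -> int:
    """Invert normalize_pixel and clamp/round back to an 8-bit value."""
    scaled = normalized * PAI_STD + PAI_MEAN
    return max(0, min(255, round(scaled * 255.0)))


if __name__ == "__main__":
    for v in (0, 97, 255):
        n = normalize_pixel(v)
        print(f"{v:3d} -> {n:+.4f} -> {denormalize_pixel(n)}")
```

Normalized values near zero correspond to pixels close to the training-set mean intensity.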

How to Use the Model

The model checkpoint stores the model configuration (e.g., model_name and num_classes) together with the weights. The architecture can therefore be rebuilt automatically before the weights are loaded.

1. Installation

The example code requires timm, torch, torchvision, huggingface_hub, and Pillow (for image loading).

pip install timm torch torchvision "huggingface_hub[cli]" pillow
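After installation, you can confirm that everything is importable by looking each package up under its import name. Two import names differ from the pip names. `pillow` is imported as `PIL`, and `huggingface_hub[cli]` is imported as `huggingface_hub`. The sketch below uses only the standard library, and the function name is hypothetical.

```python
import importlib.util

# pip package name -> name used in `import` statements
REQUIRED = {
    "timm": "timm",
    "torch": "torch",
    "torchvision": "torchvision",
    "huggingface_hub": "huggingface_hub",
    "pillow": "PIL",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of required packages that cannot be imported."""
    return [pip_name for pip_name, import_name in required.items()
            if importlib.util.find_spec(import_name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", " ".join(missing))
    else:
        print("All required packages are installed.")
```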

2. Loading the Model and Preprocessing

The following snippet shows how to download the model from the Hugging Face Hub, load it with the correct configuration, and prepare an image for inference.

import torch
import timm
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download

def load_pai_model_from_hub(repo_id="geraldoslo/PAI-meets-AI", filename="efficientnet_b3_NVIDIA_H100_PCIe_MIG_1g.20gb_20250531_121520_best.pth"):
    """
    Downloads and loads the PAI classification model from the Hugging Face Hub.
    The checkpoint contains the model configuration, so it is loaded automatically.
    """
    # Download the checkpoint file
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
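    # Load the checkpoint onto the selected device.
    # Note: since PyTorch 2.6, torch.load defaults to weights_only=True. If loading
    # fails because the checkpoint contains other Python objects, pass
    # weights_only=False, but only for checkpoints from a source you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)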
    
    # Extract model configuration from the checkpoint
    model_config = checkpoint.get("model_config", {})
    model_name = model_config.get("model_name", "efficientnet_b3")
    num_classes = model_config.get("num_classes", 5)
    
    # Create the model using timm
    model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    
    # Load the model's state dictionary
    model.load_state_dict(checkpoint["model_state_dict"])
    
    # Move model to the device and set to evaluation mode
    model.to(device)
    model.eval()
    
    print(f"Model '{model_name}' loaded successfully on '{device}'.")
    return model, device

def prepare_image(image_path):
    """
    Loads and preprocesses an image for the PAI model.
    """
    # Normalization parameters must match those used during training
    transform = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3785, 0.3785, 0.3785], std=[0.1675, 0.1675, 0.1675])
    ])
    
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0) # Add batch dimension

# --- Example Usage ---
if __name__ == '__main__':
    # 1. Load the model
    pai_model, device = load_pai_model_from_hub()

    # 2. Prepare your input image
    # Replace "path/to/your/radiograph.png" with your actual image file
    try:
        input_tensor = prepare_image("path/to/your/radiograph.png")
    except FileNotFoundError:
        # SystemExit prints the message to stderr and exits with status 1
        raise SystemExit("Error: Please replace 'path/to/your/radiograph.png' with a valid image file path.")

    # 3. Run inference
    with torch.no_grad():
        output = pai_model(input_tensor.to(device))
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class_idx].item()

    # PAI scores are 1-based, so add 1 to the 0-indexed output
    predicted_pai_score = predicted_class_idx + 1

    print(f"\nPredicted PAI score: {predicted_pai_score}")
    print(f"Confidence: {confidence:.2%}")
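The inference step in the example above (step 3 in its comments) does three things. It turns the five logits into probabilities with softmax, picks the most probable class, and shifts the 0-based class index to the 1-based PAI score. The sketch below reproduces that arithmetic in plain Python for illustration; the function name and the example logits are made up. PAI is an ordinal scale, so the full ranked distribution can be more informative than the top class alone. The reported confidence is simply the softmax probability of the top class and is not necessarily well calibrated.

```python
import math


def pai_from_logits(logits):
    """Convert 5 raw logits into (predicted PAI score, confidence, ranking).

    ranking is a list of (PAI score, probability) pairs, most probable first.
    """
    # Numerically stable softmax: subtract the max logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # argmax over the 0-indexed classes, then +1 to get the 1-based PAI score
    best_idx = max(range(len(probs)), key=probs.__getitem__)
    ranking = sorted(((i + 1, p) for i, p in enumerate(probs)),
                     key=lambda pair: pair[1], reverse=True)
    return best_idx + 1, probs[best_idx], ranking


if __name__ == "__main__":
    # Made-up logits for illustration; real values come from pai_model(...)
    score, conf, ranking = pai_from_logits([0.2, 2.5, 1.1, -0.4, -1.0])
    print(f"Predicted PAI score: {score} (confidence {conf:.2%})")
    for pai, p in ranking:
        print(f"  PAI {pai}: {p:.2%}")
```

With the real model, the logits for a single image can be passed as `output[0].tolist()`.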

3. Explainable AI (XAI)

The source code repository includes Jupyter notebooks and utilities for generating class activation maps (CAMs) with methods such as Grad-CAM, ScoreCAM, and LayerCAM. These methods produce heatmaps that highlight the regions of the radiograph that most influenced the model's prediction, which helps in interpreting and validating the model's behavior.
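The following illustrates how such a heatmap is formed in Grad-CAM. For each feature map A_k of a convolutional layer, Grad-CAM averages the gradient of the class score (logit) with respect to that map over all spatial positions to get a weight alpha_k. It then computes ReLU(sum_k alpha_k * A_k), so only positive evidence is kept. The sketch below implements only this combination step, in plain Python on small nested lists. It assumes the activations and gradients have already been captured, for example with forward and backward hooks on the network's last convolutional block. It also omits upsampling the map to 300x300 and overlaying it on the radiograph. It is not the repository's implementation, and the function name is hypothetical.

```python
def grad_cam_map(activations, gradients):
    """Combine captured activations and gradients into a Grad-CAM heatmap.

    activations, gradients: nested lists shaped [channels][height][width]
    for one image, taken from the same convolutional layer.
    Returns a [height][width] map scaled to [0, 1].
    """
    channels = len(activations)
    height, width = len(activations[0]), len(activations[0][0])

    # alpha_k: global average of the gradient over channel k's spatial positions
    alphas = [sum(sum(row) for row in gradients[k]) / (height * width)
              for k in range(channels)]

    # Weighted sum of the feature maps, then ReLU to keep positive evidence only
    cam = [[max(0.0, sum(alphas[k] * activations[k][i][j] for k in range(channels)))
            for j in range(width)]
           for i in range(height)]

    # Divide by the maximum so values lie in [0, 1]; an all-zero map stays zero
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


if __name__ == "__main__":
    # Two 2x2 feature maps with made-up values
    acts = [[[1.0, 0.0], [0.0, 2.0]],
            [[0.0, 3.0], [1.0, 0.0]]]
    grads = [[[0.5, 0.5], [0.5, 0.5]],       # alpha_0 = 0.5
             [[-0.2, -0.2], [-0.2, -0.2]]]   # alpha_1 = -0.2
    for row in grad_cam_map(acts, grads):
        print(row)
```

In the toy example, the channel with negative average gradient suppresses the positions where it is active. Only locations supported by the positively weighted channel light up.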

Data and Intended Use

Data Usage Statement

The dataset of radiograph clips and PAI scores used for training and testing this model is a private clinical dataset and is not shared due to patient privacy and data protection regulations. Users can reproduce the training and inference pipelines on their own data but cannot access the original dataset.

Intended Use

This model is intended for research and educational purposes only. It is not a certified medical device and should not be used for clinical diagnosis or decision-making without further validation and regulatory approval. The model's performance may vary on datasets with different characteristics (e.g., different imaging equipment, patient populations, or preprocessing).

Citation

If you use this model in your research, please cite the GitHub repository:

@software{Torgersen_PAI-meets-AI_2025,
  author = {Torgersen, Gerald},
  title = {{PAI-meets-AI: A Deep Learning Framework for Periapical Index Classification}},
  month = {6},
  year = {2025},
  publisher = {GitHub},
  version = {1.0.0},
  url = {https://github.com/geraldOslo/PAI-meets-AI}
}