ResNet18 Pneumonia Detection Model

This model is a ResNet18, pre-trained on ImageNet and fine-tuned for pneumonia detection on the Kaggle Chest X-ray Pneumonia dataset, which contains chest X-ray images of normal lungs and lungs with pneumonia. It classifies a chest X-ray as either Pneumonia or Normal.

Model Details

  • Model Architecture: ResNet18
  • Input Size: 224 x 224
  • Number of Classes: 2 (Pneumonia, Normal)
  • Framework: PyTorch
  • Training Dataset: Kaggle Chest X-ray Pneumonia Dataset

Model Performance

  • Accuracy: 83.3%
  • Loss: 0.2459
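The card does not state which data split these figures were measured on. As a rough guide to what the two numbers mean, the sketch below computes accuracy and mean cross-entropy loss (the quantity `torch.nn.CrossEntropyLoss` reports with its default mean reduction) from raw two-class logits. It uses plain Python so the arithmetic is visible; the logits and labels are made-up values, not outputs of this model.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-probability of the target class for one sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def evaluate(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> tuple[float, float]:
    """Return (accuracy, mean cross-entropy loss) over a batch."""
    n = len(targets)
    # The predicted class is the index of the largest logit (argmax).
    correct = sum(
        1 for row, t in zip(batch_logits, targets)
        if max(range(len(row)), key=row.__getitem__) == t
    )
    loss = sum(cross_entropy(row, t) for row, t in zip(batch_logits, targets)) / n
    return correct / n, loss


# Hypothetical logits for three X-rays and their true class indices.
logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 0.5]]
targets = [0, 1, 1]
acc, loss = evaluate(logits, targets)
print(f"accuracy={acc:.3f} loss={loss:.4f}")
```

Note that the two metrics can disagree: a prediction that is correct but not confident still adds to the loss, and a wrong prediction adds more the more confident it is.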

Intended Use

This model is designed to assist healthcare professionals in identifying pneumonia from chest X-ray images. It should not be used as a sole diagnostic tool but as a supplement to medical expertise.

Training Details

The model was trained using the following setup:

  • Architecture: ResNet18 (Pre-trained on ImageNet)
  • Optimizer: SGD (Stochastic Gradient Descent)
    • Learning Rate: 0.001
    • Momentum: 0.9
  • Loss Function: CrossEntropyLoss
  • Batch Size: 32
  • Data Augmentation:
    • Random Rotation (±30 degrees)
    • Random Zoom (20%)
    • Random Horizontal Shift (±10% width)
    • Random Vertical Shift (±10% height)
    • Random Horizontal Flip
  • Training Epochs: 1
  • Evaluation Metric: Cross Entropy Loss

Augmentation Details

The dataset was augmented during training with the following transformations, each applied at random:

  • Rotation by up to ±30 degrees
  • Zoom by up to 20%
  • Horizontal shift by up to 10% of the image width
  • Vertical shift by up to 10% of the image height
  • Horizontal flip

How to Use the Model

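You can use this model with the torch, torchvision and huggingface_hub libraries, plus Pillow to read images (requests is only needed when loading an image from a URL). The transformers library is not required.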

from io import BytesIO

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(repo_id="izeeek/resnet18_pneumonia_classifier", filename="resnet18_pneumonia_classifier.pth")

# Build the ResNet18 architecture without pretrained weights (torchvision >= 0.13);
# the downloaded checkpoint supplies all parameters
model = models.resnet18(weights=None)

# Replace the 1000-class ImageNet head with a 2-class head so the checkpoint's shapes match
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Load the downloaded weights; map_location="cpu" lets this run on machines without a GPU
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Set the model to evaluation mode (BatchNorm layers use their running statistics)
model.eval()

# Image preprocessing
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # grayscale, replicated to 3 channels
    transforms.Resize((224, 224)),
    transforms.ToTensor(),                          # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # rescale to [-1, 1]
])

# Sample image: a local path or an http(s) URL (replace with your own image).
# This path assumes the Kaggle dataset has been downloaded and extracted locally.
image_source = "chest_xray/test/NORMAL/IM-0005-0001.jpeg"
if image_source.startswith(("http://", "https://")):
    response = requests.get(image_source, timeout=30)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
else:
    img = Image.open(image_source)

# Preprocess the image and add a batch dimension: shape (1, 3, 224, 224)
input_img = transform(img).unsqueeze(0)

# Inference without gradient tracking
with torch.no_grad():
    output = model(input_img)
    probabilities = torch.softmax(output, dim=1)
    confidence, predicted = torch.max(probabilities, dim=1)

# Labels for classification (index order as given by the model author)
labels = {0: 'Pneumonia', 1: 'Normal'}
print(f'Predicted label: {labels[predicted.item()]} (probability {confidence.item():.3f})')
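The label dictionary above follows the model author's index order (0 = Pneumonia, 1 = Normal). If you fine-tune or re-evaluate the model yourself on the Kaggle folder layout (`NORMAL/` and `PNEUMONIA/` subfolders) using torchvision's `ImageFolder`, class indices come from the alphabetically sorted folder names, which gives `NORMAL` = 0 and `PNEUMONIA` = 1, the reverse of that dictionary. The sketch below reproduces the sorting rule in plain Python so you can check which order a data directory implies; the helper name `folder_class_to_idx` is hypothetical.

```python
import os
import tempfile


def folder_class_to_idx(directory: str) -> dict[str, int]:
    """Map class-folder names to indices the way torchvision's ImageFolder does:
    subdirectory names sorted alphabetically, then numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Recreate the Kaggle split layout; empty folders are enough for this check.
with tempfile.TemporaryDirectory() as root:
    for name in ("PNEUMONIA", "NORMAL"):
        os.makedirs(os.path.join(root, name))
    mapping = folder_class_to_idx(root)

print(mapping)  # {'NORMAL': 0, 'PNEUMONIA': 1}
```

If your own mapping differs from the one used to train a checkpoint, predictions will be silently swapped, so confirm the order before comparing results.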