🧠 Brain Tumor Classification Using Vision Transformer (ViT)

This repository contains a fine-tuned Vision Transformer (ViT) model trained on a large collection of MRI scans for brain tumor classification. The model classifies MRI images into one of three categories:

  • Glioma
  • Meningioma
  • Tumor (General)

The training dataset comprises over 75,000 color-enhanced MRI images. The model is intended for research and educational work on brain tumor detection.


📊 Dataset Information

Note: The dataset is publicly available for non-commercial research use. This repository does not include the dataset itself.


🧠 Model Architecture

  • Model Type: Vision Transformer (ViT-B/16)
  • Framework: PyTorch + timm
  • Input Shape: 224x224 RGB
  • Number of Classes: 3
  • Loss Function: CrossEntropyLoss
  • Optimizer: AdamW
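
For orientation, the token arithmetic implied by ViT-B/16 on a 224x224 RGB input can be worked out directly. The snippet below is only an illustrative sketch; the helper name `vit_token_layout` is hypothetical and is not part of timm or this repository.

```python
def vit_token_layout(image_size: int = 224, patch_size: int = 16, channels: int = 3):
    """Return (patches_per_side, num_patches, sequence_length, patch_dim) for a ViT."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size      # 224 / 16 = 14
    num_patches = patches_per_side ** 2              # 14 * 14 = 196
    sequence_length = num_patches + 1                # patches plus one class token
    patch_dim = patch_size * patch_size * channels   # 16 * 16 * 3 = 768 raw values per patch
    return patches_per_side, num_patches, sequence_length, patch_dim


print(vit_token_layout())  # (14, 196, 197, 768)
```

Each image is therefore seen by the transformer as a sequence of 197 tokens, and the 3-class head reads its prediction from that sequence.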

🏁 Training Pipeline Summary

  1. Image Preprocessing:

    • Resize to 224x224
    • Normalization using ImageNet stats
    • Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.
  2. DataLoader:

    • Stratified Split (Train/Val/Test)
    • PyTorch Dataset and DataLoader classes
  3. Model:

    • Loaded ViT using timm.create_model('vit_base_patch16_224', pretrained=True)
    • Modified the classifier head to match 3 output classes
  4. Training:

    • Trained using mixed precision (torch.cuda.amp)
    • Progress tracked with tqdm
  5. Saving:

    • Model saved as pytorch_model.bin
    • Configuration saved as config.json
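
The training step described above (CrossEntropyLoss, AdamW, mixed precision via torch.cuda.amp) might look roughly like the sketch below. This is not the repository's actual training script: the function name `train_one_epoch`, its signature, and the choice to enable AMP only when the scaler is enabled are assumptions made for illustration. Recent PyTorch releases also expose the same utilities under `torch.amp`.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, optimizer, scaler, device):
    """Run one epoch with optional mixed precision; returns the mean loss per sample.

    Hypothetical helper: `loader` yields (images, labels) batches and `scaler`
    is a torch.cuda.amp.GradScaler (pass enabled=False to train in full precision).
    """
    criterion = nn.CrossEntropyLoss()
    use_amp = scaler.is_enabled()
    model.train()
    total_loss, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Forward pass runs in reduced precision where autocast considers it safe
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(images), labels)
        # Scale the loss to limit fp16 gradient underflow, then step and update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * images.size(0)
        seen += images.size(0)
    return total_loss / max(seen, 1)
```

On a GPU, a call could take the form `train_one_epoch(model, train_loader, optimizer, torch.cuda.amp.GradScaler(), torch.device("cuda"))`, where `train_loader` is a placeholder name for the training DataLoader and `optimizer` is a `torch.optim.AdamW` instance.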

πŸ” Intended Use

This model is designed for:

  • Educational purposes (deep learning and medical imaging)
  • Research in brain tumor classification using transformers
  • Demonstrating ViT fine-tuning on color-enhanced medical imaging datasets

⚠️ Not intended for clinical use or deployment without regulatory approval and further validation.


🚀 Inference Example (Python)

from timm import create_model
import torch
from torchvision import transforms
from PIL import Image

# Load model
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
# map_location="cpu" lets the checkpoint load on machines without a GPU
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

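# Transform (must match training; the pipeline summary above states ImageNet stats)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])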

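# Inference (no gradients are needed, so disable autograd tracking)
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    output = model(tensor)
pred = torch.argmax(output, dim=1)
print("Predicted class index:", pred.item())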
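
The model output in the example above is a tensor of raw logits, so the printed value is only a class index. A minimal sketch of turning logits into a label and a confidence score follows. The helper `describe_prediction` is hypothetical, and the order of `CLASS_NAMES` is an assumption for illustration only: this card does not state which index maps to which category, so substitute the label encoding used at training time.

```python
import torch

# ASSUMPTION: hypothetical index-to-label order; verify against the training label encoding.
CLASS_NAMES = ["Glioma", "Meningioma", "Tumor (General)"]


def describe_prediction(logits: torch.Tensor, class_names=CLASS_NAMES):
    """Map a (1, num_classes) logit tensor to (label, confidence, all probabilities)."""
    probs = torch.softmax(logits, dim=1).squeeze(0)  # probabilities sum to 1
    idx = int(torch.argmax(probs))
    return class_names[idx], float(probs[idx]), probs.tolist()


# Stand-in logits for illustration; in practice pass the model output instead
label, confidence, probs = describe_prediction(torch.tensor([[0.2, 2.0, -1.0]]))
print(label, round(confidence, 3))
```

Softmax confidence is not a calibrated clinical probability, which is consistent with the research-only scope stated under Intended Use.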