Brain Tumor Classification Using Vision Transformer (ViT)
This repository contains a fine-tuned Vision Transformer (ViT) model trained on a large collection of MRI scans for brain tumor classification. The model classifies MRI images into one of three categories:
- Glioma
- Meningioma
- Tumor (General)
The training dataset contains over 75,000 color-enhanced MRI images, which makes the model suitable for research and educational work on brain tumor detection.
Dataset Information
- Original Dataset Name: Brain Cancer - MRI dataset
- Author: Rahman, Md Mizanur (2024)
- Hosted on: Mendeley Data
- DOI: 10.17632/mk56jw9rns.1
- Kaggle Rehost (Colorized): Shuvo Kumar Basak on Kaggle
Note: This dataset is publicly available for non-commercial research use. The model does not include the dataset itself.
Model Architecture
- Model Type: Vision Transformer (ViT-B/16)
- Framework: PyTorch + timm
- Input Shape: 224x224 RGB
- Number of Classes: 3
- Loss Function: CrossEntropyLoss
- Optimizer: AdamW
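As a quick check on the input shape listed above, ViT-B/16 cuts each 224x224 RGB image into non-overlapping 16x16 patches. The arithmetic below is a small sketch of the resulting token sequence; the variable names are illustrative, and the extra token assumes the standard class-token variant of ViT.

```python
# Patch arithmetic for a ViT-B/16 on 224x224 RGB input.
image_size = 224
patch_size = 16
channels = 3

patches_per_side = image_size // patch_size              # patches along one edge
num_patches = patches_per_side ** 2                      # patches per image
values_per_patch = patch_size * patch_size * channels    # raw pixel values in one patch
seq_len = num_patches + 1                                # plus one class token (assumed)

print(patches_per_side, num_patches, values_per_patch, seq_len)
```

Each patch is flattened and linearly projected before entering the transformer, so the model sees a sequence of patch tokens rather than a pixel grid.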
Training Pipeline Summary
Image Preprocessing:
- Resize to 224x224
- Normalization using ImageNet stats
- Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.
DataLoader:
- Stratified Split (Train/Val/Test)
- PyTorch Dataset and DataLoader classes
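A stratified split keeps the class proportions the same in every subset. The sketch below shows one way to do this with only the standard library. The function name `stratified_split`, the 80/10/10 ratio, and the toy label list are all assumptions for illustration; the card does not state the split ratios that were used.

```python
import random
from collections import defaultdict

def stratified_split(labels, val_frac=0.1, test_frac=0.1, seed=42):
    """Return (train, val, test) index lists with per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, val, test = [], [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_frac)
        n_test = round(len(indices) * test_frac)
        val += indices[:n_val]
        test += indices[n_val:n_val + n_test]
        train += indices[n_val + n_test:]
    return train, val, test

# Toy labels standing in for the real dataset
labels = ["glioma"] * 50 + ["meningioma"] * 30 + ["tumor"] * 20
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

The resulting index lists can be passed to `torch.utils.data.Subset` to build the per-split datasets that each `DataLoader` wraps.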
Model:
- Loaded ViT using timm.create_model('vit_base_patch16_224', pretrained=True)
- Modified the classifier head to match 3 output classes
Training:
- Trained using mixed precision (torch.cuda.amp)
- Tracked using tqdm
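A mixed-precision epoch with `torch.cuda.amp` and a `tqdm` progress bar might look like the sketch below. The function name `train_one_epoch`, the learning rate, the batch size, and the tiny stand-in model and random data are assumptions chosen so the example runs quickly on any machine; they are not the settings used to train this model. Mixed precision is enabled only when a CUDA device is present.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, criterion, scaler, device, use_amp):
    """Run one epoch and return the mean loss per sample."""
    model.train()
    total_loss, n_seen = 0.0, 0
    for images, labels in tqdm(loader, desc="train", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Forward pass runs in reduced precision where it is safe to do so
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)
        # Scale the loss to avoid fp16 gradient underflow, then step and update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * images.size(0)
        n_seen += images.size(0)
    return total_loss / n_seen

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

# Stand-in model and data (assumed); replace with the ViT and real DataLoader
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 3)).to(device)
data = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 3, (16,)))
loader = DataLoader(data, batch_size=8, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr is an assumption
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

epoch_loss = train_one_epoch(model, loader, optimizer, criterion, scaler, device, use_amp)
print(f"mean training loss: {epoch_loss:.4f}")
```

With `enabled=False`, both `autocast` and `GradScaler` become pass-throughs, so the same loop also works in full precision on CPU.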
Saving:
- Model saved as pytorch_model.bin
- Configuration saved as config.json
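The saving step can be sketched as writing the model's `state_dict` and a small JSON file. The keys in `config` are hypothetical, since the card does not show the contents of config.json, and a small linear layer stands in for the ViT so the example runs quickly.

```python
import json
import tempfile
from pathlib import Path

import torch
from torch import nn

model = nn.Linear(8, 3)  # stand-in for the fine-tuned ViT

# Hypothetical configuration fields; the real config.json may differ
config = {
    "architecture": "vit_base_patch16_224",
    "num_classes": 3,
    "input_size": [3, 224, 224],
}

out_dir = Path(tempfile.mkdtemp())
torch.save(model.state_dict(), out_dir / "pytorch_model.bin")
with open(out_dir / "config.json", "w") as f:
    json.dump(config, f, indent=2)

# Round trip: load the weights back into a fresh model of the same shape
restored = nn.Linear(8, 3)
restored.load_state_dict(torch.load(out_dir / "pytorch_model.bin", map_location="cpu"))
with open(out_dir / "config.json") as f:
    loaded_config = json.load(f)
print(sorted(p.name for p in out_dir.iterdir()))
```

Saving only the `state_dict` keeps the file independent of the Python class definition, which is why the inference code has to rebuild the architecture before loading the weights.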
Intended Use
This model is designed for:
- Educational purposes (deep learning and medical imaging)
- Research in brain tumor classification using transformers
- Demonstrating the power of ViT on colorized medical datasets
Warning: This model is not intended for clinical use or deployment without regulatory approval and further validation.
Inference Example (Python)
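import torch
from PIL import Image
from timm import create_model
from torchvision import transforms

# Load the fine-tuned weights into a ViT-B/16 with a 3-class head
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Preprocessing: resize, convert to tensor, normalize.
# Note: the training summary lists ImageNet statistics, so confirm that these
# mean/std values match the normalization actually used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Inference on a single image (gradients are not needed)
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)  # add a batch dimension
with torch.no_grad():
    output = model(tensor)
pred = torch.argmax(output, dim=1)
print("Predicted class:", pred.item())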
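The inference example prints a class index. To report a label and a confidence instead, the logits can be passed through a softmax, as in the sketch below. The order of `CLASS_NAMES` is hypothetical: the card does not state which index maps to which class, so check it against the training label encoding before relying on it. The logits here are made-up values standing in for a real model output.

```python
import torch

# Hypothetical index-to-label order; verify against the training setup
CLASS_NAMES = ["Glioma", "Meningioma", "Tumor (General)"]

# Made-up logits standing in for `output` from the model
output = torch.tensor([[2.0, 0.5, -1.0]])

probs = torch.softmax(output, dim=1)   # convert logits to probabilities
pred_idx = int(torch.argmax(probs, dim=1).item())
confidence = float(probs[0, pred_idx])

print(f"{CLASS_NAMES[pred_idx]} ({confidence:.1%})")
```

Softmax probabilities from a classifier are not calibrated by default, so they are best read as a relative ranking of the classes, not as a clinical likelihood.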