---
language: en
tags:
- vision
- image-classification
- medical-imaging
- tumor-classification
license: apache-2.0
base_model: google/vit-base-patch16-224
model-index:
- name: vit_tumor_classifier
  results:
  - task:
      name: Image Classification
      type: binary-classification
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.85
    - name: F1 Score
      type: f1
      value: 0.84
---

# Vision Transformer for Tumor Classification

This model is a fine-tuned version of [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) for binary tumor classification in medical images.

## Model Details

- **Model Type:** Vision Transformer (ViT)
- **Base Model:** google/vit-base-patch16-224
- **Task:** Binary Image Classification
- **Training Data:** Medical image dataset with tumor/non-tumor annotations
- **Input:** Medical images (224x224 pixels)
- **Output:** Binary classification (tumor/non-tumor)
- **Model Size:** 85.8M parameters
- **Framework:** PyTorch
- **License:** Apache 2.0

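
The 85.8M figure listed above can be roughly reproduced by hand. The sketch below is a back-of-the-envelope count, not the checkpoint's actual layout: it assumes the standard ViT-Base/16 hyperparameters (hidden size 768, 12 encoder layers, MLP size 3072, 16x16 patches on a 224x224 input) and a single linear two-class head with no pooler layer. The helper names `linear` and `vit_param_count` are illustrative and not part of any library.

```python
# Rough parameter count for ViT-Base/16 with a 2-class linear head.
# Assumptions (not confirmed by the model card): standard ViT-Base
# hyperparameters, biases on every linear layer, no pooler layer.

def linear(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def vit_param_count(image_size=224, patch_size=16, channels=3, hidden=768,
                    mlp_dim=3072, layers=12, num_classes=2) -> int:
    num_patches = (image_size // patch_size) ** 2      # 14 * 14 = 196
    # Patch projection: a conv layer equivalent to a linear map per patch
    patch_embed = linear(patch_size * patch_size * channels, hidden)
    cls_token = hidden
    pos_embed = (num_patches + 1) * hidden             # +1 for the [CLS] token
    layer_norm = 2 * hidden                            # scale and shift

    attention = 4 * linear(hidden, hidden)             # Q, K, V, output projection
    mlp = linear(hidden, mlp_dim) + linear(mlp_dim, hidden)
    encoder_layer = attention + mlp + 2 * layer_norm

    backbone = (patch_embed + cls_token + pos_embed
                + layers * encoder_layer + layer_norm)  # final LayerNorm
    head = linear(hidden, num_classes)
    return backbone + head


total = vit_param_count()
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
# prints: 85,800,194 parameters (~85.8M)
```

To check the count against the loaded checkpoint itself, `sum(p.numel() for p in model.parameters())` on the PyTorch model gives the exact number.
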

## Intended Use

This model is designed for tumor classification in medical imaging. It should be used as part of a larger medical diagnostic system and not as a standalone diagnostic tool.

## Usage

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load model and processor
processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
model = AutoModelForImageClassification.from_pretrained("SIATCN/vit_tumor_classifier")

# Load and process image. Convert to RGB because ViT expects three channels
# and medical images are often stored as grayscale.
image = Image.open("path_to_your_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Make prediction (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
predictions = outputs.logits.softmax(dim=-1)
predicted_label = predictions.argmax(dim=-1).item()
confidence = predictions[0, predicted_label].item()

# Get class name
class_names = ["non-tumor", "tumor"]
print(f"Predicted: {class_names[predicted_label]} (confidence: {confidence:.2f})")