---
base_model: google/medgemma-4b-it
library_name: transformers
model_name: medgemma-brain-cancer
tags:
- generated_from_trainer
- trl
- sft
- medical
- mri
- brain_tumor
license: apache-2.0
language:
- en
pipeline_tag: image-text-to-text
metrics:
- accuracy
- f1
model-index:
- name: finetuned-model
  results:
  - task:
      type: image-text-to-text
    dataset:
      name: orvile/brain-cancer-mri-dataset
      type: image-text-to-text
    metrics:
    - name: accuracy
      type: accuracy
      value: 0.8927392739273927
    - name: f1
      type: f1
      value: 0.892641793935792
---

# 🧠 MedGemma-Brain-Cancer

`medgemma-brain-cancer` is a fine-tuned version of [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it), trained to classify brain tumors in MRI scans. Because the base model is a vision-language model, classification is posed as a multiple-choice question about the image: the model answers with one of the three dataset classes (brain glioma, brain menin, or brain tumor) rather than through a separate classification head.

## 🔬 Model Details

* **Base Model**: [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it)
* **Dataset**: [orvile/brain-cancer-mri-dataset](https://www.kaggle.com/datasets/orvile/brain-cancer-mri-dataset)
* **Fine-tuning Approach**: Supervised fine-tuning (SFT) using [TRL (Transformer Reinforcement Learning)](https://github.com/huggingface/trl)
* **Task**: Brain tumor classification from MRI images
* **Pipeline Tag**: `image-text-to-text`
* **Accuracy Improvement**:
  * Base model accuracy: **33%** (about what random guessing among the three options would score)
  * Fine-tuned model accuracy: **89%**
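
The model metadata reports an accuracy of 0.8927 and an F1 of 0.8926. The card does not state how F1 was averaged, so the sketch below assumes macro averaging (the unweighted mean of per-class F1) and computes both metrics in plain Python from lists of reference and predicted labels. The label names `brain_glioma`, `brain_menin`, and `brain_tumor` are assumed from the answer options in the inference prompt; this is an illustration, not the evaluation code from the notebook.

```python
# Class labels assumed from the answer options used in the inference prompt.
LABELS = ["brain_glioma", "brain_menin", "brain_tumor"]


def accuracy(y_true: list[str], y_pred: list[str]) -> float:
    """Fraction of predictions that exactly match the reference label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: list[str], y_pred: list[str], labels: list[str] = LABELS) -> float:
    """Unweighted mean of per-class F1 scores (assumed averaging scheme)."""
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # A class that is never predicted correctly contributes an F1 of 0.
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


# Toy example: one brain_tumor image misclassified as brain_menin.
y_true = ["brain_glioma", "brain_menin", "brain_tumor", "brain_menin"]
y_pred = ["brain_glioma", "brain_menin", "brain_menin", "brain_menin"]
print(accuracy(y_true, y_pred))            # 0.75
print(round(macro_f1(y_true, y_pred), 3))  # 0.6
```

With scikit-learn installed, `sklearn.metrics.f1_score(y_true, y_pred, average="macro")` should give the same macro F1; `average="weighted"` would be the right choice if the reported figure is weighted by class frequency.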

## Results & Notebook

Explore the training pipeline, evaluation results, and experiments in the notebook:

**[Fine\_tuning\_MedGemma.ipynb](https://huggingface.co/kingabzpro/medgemma-brain-cancer/blob/main/Fine_tuning_MedGemma.ipynb)**
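
The notebook is the reference for how training examples were built. As a rough, unverified sketch, the code below shows one plausible way to turn the Kaggle dataset's folder-per-class layout (for example `Brain_Cancer/brain_menin/brain_menin_0002.jpg`) into chat-formatted SFT records. It reuses the multiple-choice question from the inference prompt as the user turn and the expected answer (such as `B: brain menin`) as the assistant turn. The helper names (`answer_for`, `build_sft_records`) and the record layout are hypothetical; the exact format expected by the notebook's data collator is not shown in this card. Images are kept as file paths so they can be loaded with PIL at training time.

```python
from pathlib import Path

# Answer options copied from the inference prompt; class folder names are
# assumed to match these identifiers (e.g. Brain_Cancer/brain_menin/*.jpg).
OPTIONS = {"A": "brain_glioma", "B": "brain_menin", "C": "brain_tumor"}

QUESTION = (
    "What is the most likely type of brain cancer shown in the MRI image?\n"
    + "\n".join(f"{letter}: {label.replace('_', ' ')}" for letter, label in OPTIONS.items())
)


def answer_for(label: str) -> str:
    """Return the target answer string for a class, e.g. 'B: brain menin'."""
    for letter, option in OPTIONS.items():
        if option == label:
            return f"{letter}: {label.replace('_', ' ')}"
    raise ValueError(f"unknown label: {label}")


def build_sft_records(root: str) -> list[dict]:
    """Hypothetical helper: one chat-formatted record per image under root/<label>/."""
    records = []
    for label in OPTIONS.values():
        folder = Path(root) / label
        if not folder.is_dir():
            continue  # skip classes missing from this copy of the dataset
        for path in sorted(folder.glob("*.jpg")):
            records.append({
                "image": str(path),  # path only; open with PIL when batching
                "messages": [
                    {"role": "user", "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION},
                    ]},
                    {"role": "assistant", "content": [
                        {"type": "text", "text": answer_for(label)},
                    ]},
                ],
            })
    return records
```

For example, `build_sft_records("Brain_Cancer raw MRI data/Brain_Cancer")` would produce one record per image, assuming the downloaded archive extracts to that layout. Building `QUESTION` from the same option list used for the answers keeps the training prompt identical to the inference prompt.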

## Inference Example

```python
# pip install transformers accelerate pillow
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch

model_id = "kingabzpro/medgemma-brain-cancer"

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Example brain MRI image (attribution: Orvile, via the Kaggle dataset).
# Kaggle's direct file links are signed, time-limited URLs, so load the image
# from a local copy of the dataset and adjust the path to where you extracted it.
image_path = "Brain_Cancer raw MRI data/Brain_Cancer/brain_menin/brain_menin_0002.jpg"
image = Image.open(image_path).convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "What is the most likely type of brain cancer shown in the MRI image?\nA: brain glioma\nB: brain menin\nC: brain tumor"},
        ],
    }
]

# Tokenize the chat prompt and preprocess the image; floating-point tensors
# (the pixel values) are cast to bfloat16 to match the model weights.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

# Greedy decoding; keep only the newly generated tokens.
with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
```

**Expected Output:**

```text
B: brain menin
```
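
The model answers with an option letter and name, as in the expected output above. When running the prompt over many images (for example, to reproduce the accuracy figures), a small parser can map the generated text back to a class label. The sketch below is a hypothetical helper, not part of the released code; it assumes the three options from the prompt and returns `None` when the output contains no recognizable answer.

```python
import re
from typing import Optional

# Option letters from the multiple-choice prompt, mapped to dataset class names.
OPTIONS = {"A": "brain_glioma", "B": "brain_menin", "C": "brain_tumor"}

# A standalone A/B/C letter at the start of the output; the word boundary
# stops words such as "Brain" from being read as option B.
ANSWER_RE = re.compile(r"^\s*([ABC])\b")


def parse_prediction(generated: str) -> Optional[str]:
    """Map decoded model output such as 'B: brain menin' to 'brain_menin'."""
    match = ANSWER_RE.match(generated)
    if match:
        return OPTIONS[match.group(1)]
    # Fall back to an option name written out anywhere in the text.
    text = generated.lower()
    for label in OPTIONS.values():
        if label.replace("_", " ") in text:
            return label
    return None


print(parse_prediction("B: brain menin"))  # brain_menin
print(parse_prediction("I am not sure"))   # None
```

Counting `None` results as errors, rather than dropping them, keeps any accuracy computed from these labels comparable to a strict exact-match evaluation.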

## 🧪 Intended Use

This model is intended for research and educational purposes related to medical imaging, specifically brain tumor classification. It is **not** a certified diagnostic tool and should not be used for clinical decision-making without further validation.

## 🏷️ Tags

* `medical`
* `brain_tumor`
* `mri`
* `trl`
* `sft`

## License

Apache 2.0 License