diff --git "a/Fine_tuning_MedGemma.ipynb" "b/Fine_tuning_MedGemma.ipynb" new file mode 100644--- /dev/null +++ "b/Fine_tuning_MedGemma.ipynb" @@ -0,0 +1,1240 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "445e7bc9-e783-48e6-9f09-3aa295c1d216", + "metadata": {}, + "source": [ + "## Installing Python Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1124aadd-600c-4ff3-88b8-db724d3a8071", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", + "\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "! pip install --upgrade --quiet transformers bitsandbytes datasets evaluate peft trl scikit-learn kaggle" + ] + }, + { + "cell_type": "markdown", + "id": "ed3b4bee-6824-4c16-ac55-30285152a199", + "metadata": {}, + "source": [ + "## Loading and Processing the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "621190d1-2d43-410d-9bd4-b586918bab81", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset URL: https://www.kaggle.com/datasets/orvile/brain-cancer-mri-dataset\n", + "License(s): CC-BY-SA-4.0\n", + "Downloading brain-cancer-mri-dataset.zip to /workspace\n", + " 90%|████████████████████████████████████▉ | 130M/144M [00:00<00:00, 231MB/s]\n", + "100%|█████████████████████████████████████████| 144M/144M [00:01<00:00, 108MB/s]\n" + ] + } + ], + "source": [ + "!kaggle datasets download -d orvile/brain-cancer-mri-dataset --unzip" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "66e9b3ee-7220-4a6c-b573-895cdb9438c5", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dc0b52fa43b947aaa860b5819ac93642", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/6056 [00:00" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Display the first image (as a PIL object)\n", + "data[\"train\"][0][\"image\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6fbd1185-7384-4211-8688-91d51a839f51", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + } + ], + "source": [ + "# Display the corresponding label\n", + "print(data[\"train\"][0][\"label\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "85470739-3559-406b-af67-07a6d2ccbc27", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Detected classes: ['brain_glioma', 'brain_menin', 'brain_tumor']\n" + ] + } + ], + "source": [ + "BRAIN_CANCER_CLASSES = data[\"train\"].features[\"label\"].names\n", + "print(\"Detected classes:\", BRAIN_CANCER_CLASSES)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8048ede-ccc5-4b88-b8bf-e3e5d4c2ee1d", + "metadata": {}, + "outputs": [], + "source": [ + "BRAIN_CANCER_CLASSES = ['A: brain glioma', 'B: brain menin', 'C: brain tumor']\n", + "\n", + "options = \"\\n\".join(BRAIN_CANCER_CLASSES)\n", + "PROMPT = f\"What is the most likely type of brain cancer shown in the MRI image?\\n{options}\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "20e7b335-2b7c-40d9-96e8-c7c7d568cd3f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb322285f490480ca5a9343c48996eeb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/4844 [00:00 dict[str, any]:\n", + " example[\"messages\"] = [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"image\",\n", + " },\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": PROMPT,\n", + " },\n", + " ],\n", + " },\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": BRAIN_CANCER_CLASSES[example[\"label\"]],\n", + " },\n", + " ],\n", + " },\n", + " ]\n", + " return example\n", + "\n", + "# Apply the formatting to the dataset\n", + "formatted_data = data.map(format_data)\n", + "\n", + "# Display a sample formatted data point\n", + "formatted_data[\"train\"][0][\"messages\"]\n" + ] + }, + { + "cell_type": "markdown", + "id": "892b4731-e9f5-4ba2-bd23-ff6cb3ac7c50", + "metadata": {}, + "source": [ + "## Loading the Model and Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e814414a-d800-452b-b664-9dcd7182eb21", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n" + ] + } + ], + "source": [ + "from huggingface_hub import login\n", + "import os\n", + "\n", + "hf_token = os.environ.get(\"HF_TOKEN\")\n", + "login(hf_token)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c3b40bcf-6e10-458c-9712-b8425457dcf7", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8e0e87b0141344b099a77009b004f932", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00\n", + " \n", + " \n", + " [76/76 1:07:40, Epoch 1/1]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation Loss
831.4993001.393780
165.9162000.373517
241.4357000.088616
320.5327000.060583
400.4176000.046456
480.3557000.039753
560.3161000.043567
640.3221000.037666
720.2875000.035154

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=76, training_loss=4.338119332727633, metrics={'train_runtime': 4112.8912, 'train_samples_per_second': 1.178, 'train_steps_per_second': 0.018, 'total_flos': 3.853875156358656e+16, 'train_loss': 4.338119332727633})" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "80e25489-48cd-4152-b762-2c58572f9489", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6237933d3e3449f291a6dfab32e4dc4e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Uploading...: 0%| | 0.00/2.87G [00:00 dict[str, any]:\n", + " example[\"messages\"] = [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"image\",\n", + " },\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": PROMPT,\n", + " },\n", + " ],\n", + " },\n", + " ]\n", + " return example" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "806d411a-32aa-41c0-bbbd-a25a6a67c844", + "metadata": {}, + "outputs": [], + "source": [ + "test_data = data[\"validation\"]\n", + "test_data = test_data.map(format_test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "badca756-d6f7-4bc0-9009-5fd2bb483666", + "metadata": {}, + "outputs": [], + "source": [ + "import evaluate\n", + "\n", + "accuracy_metric = evaluate.load(\"accuracy\")\n", + "f1_metric = evaluate.load(\"f1\")\n", + "\n", + "# Ground-truth labels\n", + "REFERENCES = test_data[\"label\"]\n", + "\n", + "\n", + "def compute_metrics(predictions: list[int]) -> dict[str, float]:\n", + " metrics = {}\n", + " metrics.update(\n", + " accuracy_metric.compute(\n", + " predictions=predictions,\n", + " references=REFERENCES,\n", + " )\n", + " )\n", + " metrics.update(\n", + " f1_metric.compute(\n", + " predictions=predictions,\n", + " references=REFERENCES,\n", + " average=\"weighted\",\n", + " )\n", + " )\n", + " return metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "70b28c91-c377-4667-b045-27fd2ecbad6d", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import ClassLabel\n", + "\n", + "test_data = test_data.cast_column(\n", + " \"label\",\n", + " ClassLabel(names=BRAIN_CANCER_CLASSES)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "01ad709b-5b7d-4355-bc2e-c03fa87c76bc", + "metadata": {}, + "outputs": [], + "source": [ + "LABEL_FEATURE = test_data.features[\"label\"]\n", + "\n", + "ALT_LABELS = dict([\n", + " (label, f\"({label.replace(': ', ') ')}\") for label in BRAIN_CANCER_CLASSES\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "849c8293-14b2-4b26-a3ea-5749a3d844e2", + "metadata": {}, + "outputs": [], + "source": [ + "def postprocess(prediction, do_full_match: bool = False) -> int:\n", + " if isinstance(prediction, str):\n", + " response_text = prediction\n", + " else:\n", + " response_text = prediction[0][\"generated_text\"]\n", + "\n", + " if do_full_match:\n", + " return LABEL_FEATURE.str2int(response_text)\n", + "\n", + " for label in BRAIN_CANCER_CLASSES:\n", + " # accept canonical or alternative wording\n", + " if label in response_text or ALT_LABELS[label] in response_text:\n", + " return LABEL_FEATURE.str2int(label)\n", + "\n", + " return -1\n" + ] + }, + { + "cell_type": "markdown", + "id": "2e1a23f4-beff-4303-8630-6320512c46c8", + "metadata": {}, + "source": [ + "### Model performance on the base model" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "ff8164fc-a5d8-4a78-9d68-93542a4c3f40", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "372a690e2ad24ed2926befbcb7c28573", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 str:\n", + " inputs = processor(text=prompt, images=image, return_tensors=\"pt\").to(\n", + " device, dtype=dtype\n", + " )\n", + " plen = inputs[\"input_ids\"].shape[-1]\n", + " with torch.inference_mode():\n", + " ids = model.generate(\n", + " **inputs,\n", + " disable_compile=disable_compile,\n", + " **gen_kwargs\n", + " )\n", + " return processor.decode(ids[0, plen:], skip_special_tokens=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "50064a57-57b4-493e-9c33-6fb1a0bb99d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model answer: Based on the MRI image, the most likely type of brain cancer is **A: brain glioma**.\n", + "\n", + "Here's why:\n", + "\n", + "* **Gliomas** are a common type of brain tumor\n" + ] + } + ], + "source": [ + "idx = 10\n", + "chat = test_data[\"messages\"][idx]\n", + "prompt = processor.apply_chat_template(\n", + " chat,\n", + " add_generation_prompt=True,\n", + " tokenize=False\n", + " )\n", + "\n", + "# run the one-sample helper\n", + "answer = predict_one(\n", + " prompt = prompt,\n", + " image = test_data[\"image\"][idx],\n", + " model = model,\n", + " processor= processor,\n", + " max_new_tokens = 40 \n", + ")\n", + "\n", + "print(\"Model answer:\", answer)" + ] + }, + { + "cell_type": "markdown", + "id": "802654b8-00e6-450d-a030-4fc5eca0fab2", + "metadata": {}, + "source": [ + "### Model performance on the fine-tuned model" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "8f8961c1-8db0-4e28-a7ce-92c2ca524141", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "328c05c9ec8746759dc8ac0c98901d66", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00,\n", + " 'label': 2}" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"validation\"][10]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}