{ "cells": [ { "cell_type": "markdown", "id": "445e7bc9-e783-48e6-9f09-3aa295c1d216", "metadata": {}, "source": [ "## Installing Python Packages" ] },
{ "cell_type": "code", "execution_count": 1, "id": "1124aadd-600c-4ff3-88b8-db724d3a8071", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" ] } ],
  "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "# TODO(review): pin exact package versions for reproducible runs.\n",
    "%pip install --upgrade --quiet transformers bitsandbytes datasets evaluate peft trl scikit-learn kaggle"
  ] },
{ "cell_type": "markdown", "id": "ed3b4bee-6824-4c16-ac55-30285152a199", "metadata": {}, "source": [ "## Loading and Processing the Dataset" ] },
{ "cell_type": "code", "execution_count": 2, "id": "621190d1-2d43-410d-9bd4-b586918bab81", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset URL: https://www.kaggle.com/datasets/orvile/brain-cancer-mri-dataset\n", "License(s): CC-BY-SA-4.0\n", "Downloading brain-cancer-mri-dataset.zip to /workspace\n", " 90%|████████████████████████████████████▉ | 130M/144M [00:00<00:00, 231MB/s]\n", "100%|█████████████████████████████████████████| 144M/144M [00:01<00:00, 108MB/s]\n" ] } ], "source": [ "!kaggle datasets download -d orvile/brain-cancer-mri-dataset --unzip" ] },
{ "cell_type": "code", "execution_count": 3, "id": "66e9b3ee-7220-4a6c-b573-895cdb9438c5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dc0b52fa43b947aaa860b5819ac93642", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Resolving data files: 0%| | 0/6056 [00:00" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Display the first image (as a PIL object)\n", "data[\"train\"][0][\"image\"]\n" ] },
{ "cell_type": "code", "execution_count": 5, "id": "6fbd1185-7384-4211-8688-91d51a839f51", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "# Display the corresponding label\n", "print(data[\"train\"][0][\"label\"])" ] },
{ "cell_type": "code", "execution_count": 6, "id": "85470739-3559-406b-af67-07a6d2ccbc27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Detected classes: ['brain_glioma', 'brain_menin', 'brain_tumor']\n" ] } ], "source": [ "BRAIN_CANCER_CLASSES = data[\"train\"].features[\"label\"].names\n",
"print(\"Detected classes:\", BRAIN_CANCER_CLASSES)" ] },
{ "cell_type": "code", "execution_count": 7, "id": "e8048ede-ccc5-4b88-b8bf-e3e5d4c2ee1d", "metadata": {}, "outputs": [],
  "source": [
    "# Build lettered answer options (\"A: brain glioma\", ...) from the dataset's\n",
    "# ClassLabel names, so option i always describes integer label i instead of\n",
    "# relying on a hardcoded list staying in the same order as the folders.\n",
    "LABEL_NAMES = data[\"train\"].features[\"label\"].names\n",
    "BRAIN_CANCER_CLASSES = [\n",
    "    f\"{chr(ord('A') + i)}: {name.replace('_', ' ')}\"\n",
    "    for i, name in enumerate(LABEL_NAMES)\n",
    "]\n",
    "\n",
    "options = \"\\n\".join(BRAIN_CANCER_CLASSES)\n",
    "PROMPT = f\"What is the most likely type of brain cancer shown in the MRI image?\\n{options}\"\n"
  ] },
{ "cell_type": "code", "execution_count": 8, "id": "20e7b335-2b7c-40d9-96e8-c7c7d568cd3f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cb322285f490480ca5a9343c48996eeb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/4844 [00:00 dict[str, any]:\n", " example[\"messages\"] = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\n", " \"type\": \"image\",\n", " },\n", " {\n", " \"type\": \"text\",\n", " \"text\": PROMPT,\n", " },\n", " ],\n", " },\n", " {\n", " \"role\": \"assistant\",\n", " \"content\": [\n", " {\n", " \"type\": \"text\",\n", " \"text\": BRAIN_CANCER_CLASSES[example[\"label\"]],\n", " },\n", " ],\n", " },\n", " ]\n", " return example\n", "\n", "# Apply the formatting to the dataset\n", "formatted_data = data.map(format_data)\n", "\n", "# Display a sample formatted data point\n", "formatted_data[\"train\"][0][\"messages\"]\n" ] },
{ "cell_type": "markdown", "id": "892b4731-e9f5-4ba2-bd23-ff6cb3ac7c50", "metadata": {}, "source": [ "## Loading the Model and Tokenizer" ] },
{ "cell_type": "code", "execution_count": 9, "id": "e814414a-d800-452b-b664-9dcd7182eb21", "metadata": {}, "outputs": [],
  "source": [
    "import os\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# huggingface_hub already authenticates with the HF_TOKEN environment\n",
    "# variable, so logging in again with the same token is redundant; only fall\n",
    "# back to an interactive login when the variable is missing.\n",
    "if not os.environ.get(\"HF_TOKEN\"):\n",
    "    login()\n"
  ] }, {
"cell_type": "code", "execution_count": 10, "id": "c3b40bcf-6e10-458c-9712-b8425457dcf7", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e0e87b0141344b099a77009b004f932", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\n", " \n", " \n", " [76/76 1:07:40, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss
831.4993001.393780
165.9162000.373517
241.4357000.088616
320.5327000.060583
400.4176000.046456
480.3557000.039753
560.3161000.043567
640.3221000.037666
720.2875000.035154

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=76, training_loss=4.338119332727633, metrics={'train_runtime': 4112.8912, 'train_samples_per_second': 1.178, 'train_steps_per_second': 0.018, 'total_flos': 3.853875156358656e+16, 'train_loss': 4.338119332727633})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] },
{ "cell_type": "code", "execution_count": 16, "id": "80e25489-48cd-4152-b762-2c58572f9489", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6237933d3e3449f291a6dfab32e4dc4e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Uploading...: 0%| | 0.00/2.87G [00:00 dict[str, any]:\n", " example[\"messages\"] = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\n", " \"type\": \"image\",\n", " },\n", " {\n", " \"type\": \"text\",\n", " \"text\": PROMPT,\n", " },\n", " ],\n", " },\n", " ]\n", " return example" ] },
{ "cell_type": "code", "execution_count": 32, "id": "806d411a-32aa-41c0-bbbd-a25a6a67c844", "metadata": {}, "outputs": [],
  "source": [
    "# Validation split with user-only prompts (no assistant answer) for generation.\n",
    "test_data = data[\"validation\"].map(format_test_data)"
  ] },
{ "cell_type": "code", "execution_count": 33, "id": "badca756-d6f7-4bc0-9009-5fd2bb483666", "metadata": {}, "outputs": [],
  "source": [
    "import evaluate\n",
    "\n",
    "accuracy_metric = evaluate.load(\"accuracy\")\n",
    "f1_metric = evaluate.load(\"f1\")\n",
    "\n",
    "# Ground-truth labels of the evaluation split\n",
    "REFERENCES = test_data[\"label\"]\n",
    "\n",
    "\n",
    "def compute_metrics(\n",
    "    predictions: list[int],\n",
    "    references: list[int] | None = None,\n",
    ") -> dict[str, float]:\n",
    "    \"\"\"Return accuracy and weighted F1 of `predictions`.\n",
    "\n",
    "    Args:\n",
    "        predictions: Predicted class ids; -1 (the \"unparseable answer\" value\n",
    "            returned by postprocess) never equals a real label, so it counts\n",
    "            as wrong.\n",
    "        references: Ground-truth class ids. Defaults to REFERENCES so\n",
    "            existing calls keep working; pass another list to score a\n",
    "            subset or a different split.\n",
    "    \"\"\"\n",
    "    if references is None:\n",
    "        references = REFERENCES\n",
    "    metrics = {}\n",
    "    metrics.update(\n",
    "        accuracy_metric.compute(predictions=predictions, references=references)\n",
    "    )\n",
    "    metrics.update(\n",
    "        f1_metric.compute(\n",
    "            predictions=predictions,\n",
    "            references=references,\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "    )\n",
    "    return metrics\n"
  ] },
{ "cell_type": "code", "execution_count": 34, "id": "70b28c91-c377-4667-b045-27fd2ecbad6d", "metadata": {}, "outputs": [],
  "source": [
    "from datasets import ClassLabel\n",
    "\n",
    "# Name the label classes with the lettered answer strings so that\n",
    "# LABEL_FEATURE.str2int can map a model answer back to a label id.\n",
    "test_data = test_data.cast_column(\n",
    "    \"label\",\n",
    "    ClassLabel(names=BRAIN_CANCER_CLASSES)\n",
    ")"
  ] },
{ "cell_type": "code", "execution_count": 35, "id": "01ad709b-5b7d-4355-bc2e-c03fa87c76bc", "metadata": {}, "outputs": [],
  "source": [
    "LABEL_FEATURE = test_data.features[\"label\"]\n",
    "\n",
    "# Alternative answer wording, e.g. \"A: brain glioma\" -> \"(A) brain glioma\".\n",
    "ALT_LABELS = {\n",
    "    label: f\"({label.replace(': ', ') ')}\" for label in BRAIN_CANCER_CLASSES\n",
    "}"
  ] },
{ "cell_type": "code", "execution_count": 36, "id": "849c8293-14b2-4b26-a3ea-5749a3d844e2", "metadata": {}, "outputs": [],
  "source": [
    "def postprocess(prediction, do_full_match: bool = False) -> int:\n",
    "    \"\"\"Map a model answer to a class id of LABEL_FEATURE, or -1 if none matches.\n",
    "\n",
    "    Args:\n",
    "        prediction: The decoded answer string, or a pipeline-style result\n",
    "            whose first element has a \"generated_text\" entry.\n",
    "        do_full_match: If True, the whole answer must be accepted by\n",
    "            LABEL_FEATURE.str2int; otherwise the first class (in\n",
    "            BRAIN_CANCER_CLASSES order) whose canonical or ALT_LABELS\n",
    "            wording appears anywhere in the answer wins.\n",
    "\n",
    "    Returns:\n",
    "        The integer class id, or -1 when the answer cannot be parsed.\n",
    "    \"\"\"\n",
    "    if isinstance(prediction, str):\n",
    "        response_text = prediction\n",
    "    else:\n",
    "        response_text = prediction[0][\"generated_text\"]\n",
    "\n",
    "    if do_full_match:\n",
    "        try:\n",
    "            return LABEL_FEATURE.str2int(response_text)\n",
    "        except ValueError:\n",
    "            # str2int raises on unknown labels; report -1 like the substring\n",
    "            # path below so one bad answer doesn't abort the evaluation.\n",
    "            return -1\n",
    "\n",
    "    for label in BRAIN_CANCER_CLASSES:\n",
    "        # accept canonical or alternative wording\n",
    "        if label in response_text or ALT_LABELS[label] in response_text:\n",
    "            return LABEL_FEATURE.str2int(label)\n",
    "\n",
    "    return -1\n"
  ] },
{ "cell_type": "markdown", "id": "2e1a23f4-beff-4303-8630-6320512c46c8", "metadata": {}, "source": [ "### Model performance on the base model" ] },
{ "cell_type": "code", "execution_count": 37, "id": "ff8164fc-a5d8-4a78-9d68-93542a4c3f40", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "372a690e2ad24ed2926befbcb7c28573", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00 str:\n", " inputs = processor(text=prompt, images=image, return_tensors=\"pt\").to(\n", " device, dtype=dtype\n", " )\n", " plen = inputs[\"input_ids\"].shape[-1]\n", " with torch.inference_mode():\n", " ids = model.generate(\n", " **inputs,\n", " 
disable_compile=disable_compile,\n", " **gen_kwargs\n", " )\n", " return processor.decode(ids[0, plen:], skip_special_tokens=True)\n" ] },
{ "cell_type": "code", "execution_count": 45, "id": "50064a57-57b4-493e-9c33-6fb1a0bb99d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model answer: Based on the MRI image, the most likely type of brain cancer is **A: brain glioma**.\n", "\n", "Here's why:\n", "\n", "* **Gliomas** are a common type of brain tumor\n" ] } ],
  "source": [
    "idx = 10\n",
    "# Fetch the row once: column access such as test_data[\"image\"][idx] can\n",
    "# materialize (and decode) the entire column first on older `datasets`.\n",
    "sample = test_data[idx]\n",
    "prompt = processor.apply_chat_template(\n",
    "    sample[\"messages\"],\n",
    "    add_generation_prompt=True,\n",
    "    tokenize=False,\n",
    ")\n",
    "\n",
    "# run the one-sample helper\n",
    "answer = predict_one(\n",
    "    prompt=prompt,\n",
    "    image=sample[\"image\"],\n",
    "    model=model,\n",
    "    processor=processor,\n",
    "    max_new_tokens=40,\n",
    ")\n",
    "\n",
    "print(\"Model answer:\", answer)"
  ] },
{ "cell_type": "markdown", "id": "802654b8-00e6-450d-a030-4fc5eca0fab2", "metadata": {}, "source": [ "### Model performance on the fine-tuned model" ] },
{ "cell_type": "code", "execution_count": 47, "id": "8f8961c1-8db0-4e28-a7ce-92c2ca524141", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "328c05c9ec8746759dc8ac0c98901d66", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00,\n", " 'label': 2}" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[\"validation\"][10]" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }