"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [3/3 00:14, Epoch 3/3]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.434700 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.434300 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1.429000 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=3, training_loss=1.432644248008728, metrics={'train_runtime': 18.9818, 'train_samples_per_second': 18.966, 'train_steps_per_second': 0.158, 'total_flos': 12741654491136.0, 'train_loss': 1.432644248008728})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 69
+ }
+ ],
+ "source": [
+ "trainer.train() # after a few iterations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OoonXeFSjGqO"
+ },
+ "source": [
+ "### saving model fp32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "WZu3e1ITu_ZA",
+ "outputId": "128003eb-668a-40f8-a0fc-ec443f559c96"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "('gemma-genz-270M-peft/tokenizer_config.json',\n",
+ " 'gemma-genz-270M-peft/special_tokens_map.json',\n",
+ " 'gemma-genz-270M-peft/chat_template.jinja',\n",
+ " 'gemma-genz-270M-peft/tokenizer.model',\n",
+ " 'gemma-genz-270M-peft/added_tokens.json',\n",
+ " 'gemma-genz-270M-peft/tokenizer.json')"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 72
+ }
+ ],
+ "source": [
+ "trainer.model.save_pretrained(\"gemma-genz-270M-peft\")\n",
+ "tokenizer.save_pretrained(\"gemma-genz-270M-peft\")"
+ ]
+ },
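+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The folder above holds only the LoRA adapter weights. As a minimal sketch (not run in this notebook; `gemma-genz-270M-merged` is a hypothetical output name), PEFT's `merge_and_unload()` folds the adapter into the base model so the checkpoint can be loaded without `peft`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: merge the LoRA adapter into the base weights for standalone use.\n",
+ "from transformers import AutoModelForCausalLM\n",
+ "from peft import PeftModel\n",
+ "\n",
+ "base = AutoModelForCausalLM.from_pretrained(\"google/gemma-3-270m-it\")\n",
+ "merged = PeftModel.from_pretrained(base, \"gemma-genz-270M-peft\").merge_and_unload()\n",
+ "merged.save_pretrained(\"gemma-genz-270M-merged\")  # hypothetical output folder"
+ ]
+ },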
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 36
+ },
+ "id": "dZ-ChN-hiuk5",
+ "outputId": "8d862716-9cb8-416f-d5e3-7d6cb1310cd9"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'/content/gemma-genz-270M-peft.zip'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 73
+ }
+ ],
+ "source": [
+ "# saving to local\n",
+ "import shutil\n",
+ "\n",
+ "# Path to your saved model\n",
+ "model_folder = \"gemma-genz-270M-peft\"\n",
+ "zip_name = \"gemma-genz-270M-peft.zip\"\n",
+ "\n",
+ "# Create zip\n",
+ "shutil.make_archive(\"gemma-genz-270M-peft\", 'zip', model_folder)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "id": "df3yGr0giuiV",
+ "outputId": "6e1ee009-56d8-44e3-a09b-3c9711339e8a"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "application/javascript": [
+ "\n",
+ " async function download(id, filename, size) {\n",
+ " if (!google.colab.kernel.accessAllowed) {\n",
+ " return;\n",
+ " }\n",
+ " const div = document.createElement('div');\n",
+ " const label = document.createElement('label');\n",
+ " label.textContent = `Downloading \"${filename}\": `;\n",
+ " div.appendChild(label);\n",
+ " const progress = document.createElement('progress');\n",
+ " progress.max = size;\n",
+ " div.appendChild(progress);\n",
+ " document.body.appendChild(div);\n",
+ "\n",
+ " const buffers = [];\n",
+ " let downloaded = 0;\n",
+ "\n",
+ " const channel = await google.colab.kernel.comms.open(id);\n",
+ " // Send a message to notify the kernel that we're ready.\n",
+ " channel.send({})\n",
+ "\n",
+ " for await (const message of channel.messages) {\n",
+ " // Send a message to notify the kernel that we're ready.\n",
+ " channel.send({})\n",
+ " if (message.buffers) {\n",
+ " for (const buffer of message.buffers) {\n",
+ " buffers.push(buffer);\n",
+ " downloaded += buffer.byteLength;\n",
+ " progress.value = downloaded;\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ " const blob = new Blob(buffers, {type: 'application/binary'});\n",
+ " const a = document.createElement('a');\n",
+ " a.href = window.URL.createObjectURL(blob);\n",
+ " a.download = filename;\n",
+ " div.appendChild(a);\n",
+ " a.click();\n",
+ " div.remove();\n",
+ " }\n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "application/javascript": [
+ "download(\"download_4cd2d03d-3920-428a-b87b-e5854dea11d0\", \"gemma-genz-270M-peft.zip\", 64074828)"
+ ]
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "from google.colab import files\n",
+ "\n",
+ "files.download(zip_name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4QRB-MGhjRbV"
+ },
+ "source": [
+ "### Inference fp32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
+ "from peft import PeftModel\n",
+ "\n",
+ "device = 0 if torch.cuda.is_available() else -1 # pipeline uses int device\n",
+ "\n",
+ "base_model_name = \"google/gemma-3-270m-it\"\n",
+ "base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32)\n",
+ "\n",
+ "peft_model_path = \"/content/gemma-genz-270M-peft\"\n",
+ "model = PeftModel.from_pretrained(base_model, peft_model_path)\n",
+ "model.eval()\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(peft_model_path)\n",
+ "\n",
+ "text_gen_pipeline = pipeline(\n",
+ " \"text-generation\",\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " device=device\n",
+ ")\n",
+ "\n",
+ "def generate_text_pipeline(prompts, max_new_tokens=100, temperature=1.0, top_k=50, top_p=0.95, do_sample=True):\n",
+ " \"\"\"\n",
+ " Generate text using the preloaded PEFT LoRA model via Hugging Face pipeline.\n",
+ "\n",
+ " Args:\n",
+ " prompts (str or list[str]): Single prompt or list of prompts.\n",
+ " max_new_tokens (int): Maximum tokens to generate beyond input.\n",
+ " temperature (float): Sampling temperature.\n",
+ " top_k (int): Top-k sampling.\n",
+ " top_p (float): Top-p sampling (nucleus).\n",
+ " do_sample (bool): Whether to sample or use greedy decoding.\n",
+ "\n",
+ " Returns:\n",
+ " list[str]: Generated text(s).\n",
+ " \"\"\"\n",
+ " if isinstance(prompts, str):\n",
+ " prompts = [prompts]\n",
+ "\n",
+ " outputs = text_gen_pipeline(\n",
+ " prompts,\n",
+ " max_new_tokens=max_new_tokens,\n",
+ " temperature=temperature,\n",
+ " top_k=top_k,\n",
+ " top_p=top_p,\n",
+ " do_sample=do_sample\n",
+ " )\n",
+ "\n",
+ " return outputs"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "6nkhf7LEW5HH",
+ "outputId": "76407587-55b7-4e5d-bfbc-c9c21ded24cc"
+ },
+ "execution_count": 47,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Device set to use cuda:0\n"
+ ]
+ }
+ ]
+ },
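+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since the base checkpoint is instruction-tuned (`-it`), the prompt can also be wrapped with the tokenizer's chat template before generation. A hedged variation on the cells below; the message content is arbitrary:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged example: format the prompt with the chat template first.\n",
+ "messages = [{\"role\": \"user\", \"content\": \"job market is dead bro!\"}]\n",
+ "chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ "generate_text_pipeline(chat_prompt, max_new_tokens=60)"
+ ]
+ },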
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "cl0M5jriwJtX",
+ "outputId": "c0836e9c-6d33-4c32-e384-c82ec68669ee"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "[[{'generated_text': 'job market is dead bro! #opportunityseek #careergoals #adultingNOW #worklifebalance\\nHarhar ke liye job search kar raha hoon, identity crisis mein.'}]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompt = \"job market is dead bro!\"\n",
+ "generated = generate_text_pipeline(prompt)\n",
+ "print(generated)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = \"life seems to be slipping away from the hands yaar\"\n",
+ "generated = generate_text_pipeline(prompt)\n",
+ "generated"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rOKFrWiiWj_W",
+ "outputId": "b5b4551f-704c-4694-bfaa-6f314b216d10"
+ },
+ "execution_count": 48,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[[{'generated_text': \"life seems to be slipping away from the hands yaar. Everything feels unpredictable, anxiety peaks and dips hard.\\n\\nI'm so sorry to hear it. Sometimes it feels like I'm just going through the motions.\"}]]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 48
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## finetune fp16"
+ ],
+ "metadata": {
+ "id": "A4SPVRqTHh6V"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### preparing for finetuning using loaded dataset"
+ ],
+ "metadata": {
+ "id": "z2L9oG8KWHhM"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "\n",
+ "base_model = \"google/gemma-3-270m-it\"\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "# Load model and tokenizer\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " base_model,\n",
+ " attn_implementation=\"eager\",\n",
+ " torch_dtype=torch.float16,\n",
+ ").to(device)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(base_model)\n",
+ "\n",
+ "print(f\"Device: {model.device}\")\n",
+ "print(f\"DType: {model.dtype}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VnGWuxGSLUCS",
+ "outputId": "cc4257f4-d2ea-4293-c478-45f959ca824d"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Device: cuda:0\n",
+ "DType: torch.float16\n"
+ ]
+ }
+ ]
+ },
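+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick size check (a sketch; `get_memory_footprint()` is a standard `transformers` model utility), the fp16 weights should come in at roughly half the fp32 footprint:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Approximate in-memory size of the fp16 model (parameters + buffers).\n",
+ "print(f\"Model footprint: {model.get_memory_footprint() / 1e6:.0f} MB\")"
+ ]
+ },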
+ {
+ "cell_type": "code",
+ "source": [
+ "from peft import LoraConfig, get_peft_model\n",
+ "\n",
+ "peft_config = LoraConfig(\n",
+ " r=64,\n",
+ " lora_alpha=64,\n",
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_dropout=0.2,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\"\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "6k5oeCgdLpXF"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "peft_model = get_peft_model(model, peft_config)\n",
+ "peft_model.print_trainable_parameters()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jen4hTLhLpMc",
+ "outputId": "5cbecd51-f874-4a9b-dffd-2a509aba2175"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "trainable params: 15,187,968 || all params: 283,286,144 || trainable%: 5.3614\n"
+ ]
+ }
+ ]
+ },
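+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A small sanity check (assumption: the LoRA layers inherit the base model's dtype when created this way): the first trainable parameter should report `torch.float16`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The trainable LoRA parameters should follow the fp16 base weights.\n",
+ "for name, param in peft_model.named_parameters():\n",
+ "    if param.requires_grad:\n",
+ "        print(name, param.dtype)\n",
+ "        break"
+ ]
+ },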
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install trl"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 0
+ },
+ "collapsed": true,
+ "id": "o6ve0bjLMMKA",
+ "outputId": "72965f46-469c-4b91-f30a-b9692228c4ab"
+ },
+ "execution_count": 34,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting trl\n",
+ " Downloading trl-0.21.0-py3-none-any.whl.metadata (11 kB)\n",
+ "Requirement already satisfied: accelerate>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from trl) (1.10.0)\n",
+ "Requirement already satisfied: datasets>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from trl) (4.0.0)\n",
+ "Requirement already satisfied: transformers>=4.55.0 in /usr/local/lib/python3.11/dist-packages (from trl) (4.55.1)\n",
+ "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (2.0.2)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (25.0)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (5.9.5)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (6.0.2)\n",
+ "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (2.6.0+cu124)\n",
+ "Requirement already satisfied: huggingface_hub>=0.21.0 in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (0.34.4)\n",
+ "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from accelerate>=1.4.0->trl) (0.6.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (3.18.0)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (18.1.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (2.2.2)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (4.67.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets>=3.0.0->trl) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (2025.3.0)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.55.0->trl) (2024.11.6)\n",
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.55.0->trl) (0.21.4)\n",
+ "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (3.12.15)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub>=0.21.0->accelerate>=1.4.0->trl) (4.14.1)\n",
+ "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub>=0.21.0->accelerate>=1.4.0->trl) (1.1.7)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=3.0.0->trl) (3.4.3)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=3.0.0->trl) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=3.0.0->trl) (2.5.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=3.0.0->trl) (2025.8.3)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (3.5)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (3.1.6)\n",
+ "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
+ "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (0.6.2)\n",
+ "Collecting nvidia-nccl-cu12==2.21.5 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (12.4.127)\n",
+ "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=2.0.0->accelerate>=1.4.0->trl)\n",
+ " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
+ "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (3.2.0)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate>=1.4.0->trl) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate>=1.4.0->trl) (1.3.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets>=3.0.0->trl) (2.9.0.post0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets>=3.0.0->trl) (2025.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets>=3.0.0->trl) (2025.2)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (2.6.1)\n",
+ "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (1.4.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (25.3.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (1.7.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (6.6.4)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (0.3.2)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=3.0.0->trl) (1.20.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets>=3.0.0->trl) (1.17.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=2.0.0->accelerate>=1.4.0->trl) (3.0.2)\n",
+ "Downloading trl-0.21.0-py3-none-any.whl (511 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m511.9/511.9 kB\u001b[0m \u001b[31m29.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m65.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m47.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m784.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.7/188.7 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m37.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, trl\n",
+ " Attempting uninstall: nvidia-nvjitlink-cu12\n",
+ " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
+ " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-nccl-cu12\n",
+ " Found existing installation: nvidia-nccl-cu12 2.23.4\n",
+ " Uninstalling nvidia-nccl-cu12-2.23.4:\n",
+ " Successfully uninstalled nvidia-nccl-cu12-2.23.4\n",
+ " Attempting uninstall: nvidia-curand-cu12\n",
+ " Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
+ " Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
+ " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
+ " Attempting uninstall: nvidia-cufft-cu12\n",
+ " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
+ " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
+ " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
+ " Attempting uninstall: nvidia-cuda-runtime-cu12\n",
+ " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
+ " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cuda-cupti-cu12\n",
+ " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cublas-cu12\n",
+ " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
+ " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
+ " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
+ " Attempting uninstall: nvidia-cusparse-cu12\n",
+ " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
+ " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
+ " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
+ " Attempting uninstall: nvidia-cudnn-cu12\n",
+ " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
+ " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
+ " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
+ " Attempting uninstall: nvidia-cusolver-cu12\n",
+ " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
+ " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
+ " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
+ "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 trl-0.21.0\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "nvidia"
+ ]
+ },
+ "id": "2964d28e12464a12bc40ec5203dd862c"
+ }
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from trl import SFTTrainer, SFTConfig\n",
+ "\n",
+ "trainer = SFTTrainer(\n",
+ " model = peft_model,\n",
+ " train_dataset = gemma3_dataset,\n",
+ " args = SFTConfig(\n",
+ " dataset_text_field=\"text\",\n",
+ " per_device_train_batch_size=32,\n",
+ " gradient_accumulation_steps=6,\n",
+ " warmup_steps=10,\n",
+ " num_train_epochs=10,\n",
+ " learning_rate=5e-5,\n",
+ " logging_strategy=\"steps\",\n",
+ " logging_steps=1,\n",
+ " optim=\"adamw_torch_fused\",\n",
+ " weight_decay=0.01,\n",
+ " lr_scheduler_type=\"linear\",\n",
+ " seed=3407,\n",
+ " fp16=False,\n",
+ " output_dir=\"outputs\",\n",
+ " report_to=\"none\",\n",
+ " )\n",
+ "\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "YmPulIk9LpKN"
+ },
+ "execution_count": 14,
+ "outputs": []
+ },
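+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, one optimizer step covers `per_device_train_batch_size * gradient_accumulation_steps = 32 * 6 = 192` sequences, which is why each epoch in the logs below collapses to a single step. A back-of-envelope check (assumes `gemma3_dataset` supports `len()`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Effective batch size and steps per epoch for the config above.\n",
+ "import math\n",
+ "\n",
+ "effective_batch = 32 * 6  # per-device batch x gradient accumulation\n",
+ "steps_per_epoch = math.ceil(len(gemma3_dataset) / effective_batch)\n",
+ "print(effective_batch, steps_per_epoch)"
+ ]
+ },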
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.train()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 425
+ },
+ "id": "meBnovQTLpGk",
+ "outputId": "57611232-9685-45e2-f7e2-80ffdd29e944"
+ },
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [10/10 00:47, Epoch 10/10]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.106000 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.107900 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1.095800 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.075000 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 1.053600 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 1.041800 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 1.028700 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 1.007100 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.973700 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.933100 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=10, training_loss=1.042256236076355, metrics={'train_runtime': 52.4478, 'train_samples_per_second': 22.88, 'train_steps_per_second': 0.191, 'total_flos': 42488815650816.0, 'train_loss': 1.042256236076355})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 29
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.train() # 8-9th iteration so 80-90 epochs"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 425
+ },
+ "id": "tXoxUWhGO7Gg",
+ "outputId": "08214f4e-9387-4a7a-adde-e5e4a09ed40e"
+ },
+ "execution_count": 30,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [10/10 00:48, Epoch 10/10]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.897500 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.892300 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.885000 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.867700 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.839400 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.824900 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.813100 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.798900 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.770800 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.737800 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=10, training_loss=0.8327455043792724, metrics={'train_runtime': 53.2992, 'train_samples_per_second': 22.514, 'train_steps_per_second': 0.188, 'total_flos': 42488815650816.0, 'train_loss': 0.8327455043792724})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 30
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Saving model fp16\n"
+ ],
+ "metadata": {
+ "id": "6o4qtUtKWMvd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.model.save_pretrained(\"gemma-genz-270M-peft-fp16\")\n",
+ "tokenizer.save_pretrained(\"gemma-genz-270M-peft-fp16\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "aSUb7g0wVkXo",
+ "outputId": "714fe605-a2e4-4f92-be79-10489a776ba0"
+ },
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "('gemma-genz-270M-peft-fp16/tokenizer_config.json',\n",
+ " 'gemma-genz-270M-peft-fp16/special_tokens_map.json',\n",
+ " 'gemma-genz-270M-peft-fp16/chat_template.jinja',\n",
+ " 'gemma-genz-270M-peft-fp16/tokenizer.model',\n",
+ " 'gemma-genz-270M-peft-fp16/added_tokens.json',\n",
+ " 'gemma-genz-270M-peft-fp16/tokenizer.json')"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 31
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# saving to local\n",
+ "import shutil\n",
+ "\n",
+ "# Path to your saved model\n",
+ "model_folder = \"gemma-genz-270M-peft-fp16\"\n",
+ "zip_name = \"gemma-genz-270M-peft-fp16.zip\"\n",
+ "\n",
+ "# Create zip\n",
+ "shutil.make_archive(\"gemma-genz-270M-peft-fp16\", 'zip', model_folder)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 36
+ },
+ "id": "tA94Um0JVsZh",
+ "outputId": "cfa2480f-ea78-43bc-b1c6-35ced736f4e7"
+ },
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'/content/gemma-genz-270M-peft-fp16.zip'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 32
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import files\n",
+ "\n",
+ "files.download(zip_name)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "id": "leOU16GSV2eR",
+ "outputId": "13dbd777-f02c-4d77-9db0-c55b681ac2ec"
+ },
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "application/javascript": [
+ "\n",
+ " async function download(id, filename, size) {\n",
+ " if (!google.colab.kernel.accessAllowed) {\n",
+ " return;\n",
+ " }\n",
+ " const div = document.createElement('div');\n",
+ " const label = document.createElement('label');\n",
+ " label.textContent = `Downloading \"${filename}\": `;\n",
+ " div.appendChild(label);\n",
+ " const progress = document.createElement('progress');\n",
+ " progress.max = size;\n",
+ " div.appendChild(progress);\n",
+ " document.body.appendChild(div);\n",
+ "\n",
+ " const buffers = [];\n",
+ " let downloaded = 0;\n",
+ "\n",
+ " const channel = await google.colab.kernel.comms.open(id);\n",
+ " // Send a message to notify the kernel that we're ready.\n",
+ " channel.send({})\n",
+ "\n",
+ " for await (const message of channel.messages) {\n",
+ " // Send a message to notify the kernel that we're ready.\n",
+ " channel.send({})\n",
+ " if (message.buffers) {\n",
+ " for (const buffer of message.buffers) {\n",
+ " buffers.push(buffer);\n",
+ " downloaded += buffer.byteLength;\n",
+ " progress.value = downloaded;\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ " const blob = new Blob(buffers, {type: 'application/binary'});\n",
+ " const a = document.createElement('a');\n",
+ " a.href = window.URL.createObjectURL(blob);\n",
+ " a.download = filename;\n",
+ " div.appendChild(a);\n",
+ " a.click();\n",
+ " div.remove();\n",
+ " }\n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "application/javascript": [
+ "download(\"download_ae1f806c-6786-40e8-99fe-68730410115f\", \"gemma-genz-270M-peft-fp16.zip\", 64093541)"
+ ]
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Inference (fp16)"
+ ],
+ "metadata": {
+ "id": "TptYreUvWQd0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
+ "from peft import PeftModel\n",
+ "\n",
+ "device = 0 if torch.cuda.is_available() else -1 # pipeline uses int device\n",
+ "\n",
+ "base_model_name = \"google/gemma-3-270m-it\"\n",
+ "base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32)\n",
+ "\n",
+ "peft_model_path = \"/content/gemma-genz-270M-peft-fp16\"\n",
+ "model = PeftModel.from_pretrained(base_model, peft_model_path)\n",
+ "model.eval()\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(peft_model_path)\n",
+ "\n",
+ "text_gen_pipeline = pipeline(\n",
+ " \"text-generation\",\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " device=device\n",
+ ")\n",
+ "\n",
+ "def generate_text_pipeline_fp16(prompts, max_new_tokens=100, temperature=1.0, top_k=50, top_p=0.95, do_sample=True):\n",
+ " \"\"\"\n",
+ " Generate text using the preloaded PEFT LoRA model via Hugging Face pipeline.\n",
+ "\n",
+ " Args:\n",
+ " prompts (str or list[str]): Single prompt or list of prompts.\n",
+ " max_new_tokens (int): Maximum tokens to generate beyond input.\n",
+ " temperature (float): Sampling temperature.\n",
+ " top_k (int): Top-k sampling.\n",
+ " top_p (float): Top-p sampling (nucleus).\n",
+ " do_sample (bool): Whether to sample or use greedy decoding.\n",
+ "\n",
+ " Returns:\n",
+ " list[str]: Generated text(s).\n",
+ " \"\"\"\n",
+ " if isinstance(prompts, str):\n",
+ " prompts = [prompts]\n",
+ "\n",
+ " outputs = text_gen_pipeline(\n",
+ " prompts,\n",
+ " max_new_tokens=max_new_tokens,\n",
+ " temperature=temperature,\n",
+ " top_k=top_k,\n",
+ " top_p=top_p,\n",
+ " do_sample=do_sample\n",
+ " )\n",
+ "\n",
+ " return outputs"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HnpB3W8-WUO0",
+ "outputId": "6f482664-ff79-49b4-d5b8-470a63c08991"
+ },
+ "execution_count": 54,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Device set to use cuda:0\n"
+ ]
+ }
+ ]
+ },
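+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For a more repeatable comparison between the fp32 and fp16 adapters, greedy decoding removes the sampling noise; a quick sketch using the wrapper's `do_sample` flag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Greedy decoding (no sampling) gives deterministic output for comparison.\n",
+ "generate_text_pipeline_fp16(\"job market is dead bro!\", max_new_tokens=40, do_sample=False)"
+ ]
+ },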
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = \"job market is dead bro!\"\n",
+ "generated = generate_text_pipeline_fp16(prompt, max_new_tokens=50)\n",
+ "generated"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "klat4M0_aaqE",
+ "outputId": "2b04824d-f8e2-46d9-daa3-9fea5b2bb8a3"
+ },
+ "execution_count": 67,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[[{'generated_text': 'job market is dead bro! Opportunities are still scarce, but demand outstrihes supply. #jobsearching #careeradvice #FOMO\\n'}]]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 67
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = \"life seems to be slipping away from the hands yaar\"\n",
+ "generated = generate_text_pipeline(prompt)\n",
+ "generated"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "-uYdAATeaS2s",
+ "outputId": "eccea201-9064-4b74-934f-1ec80f9ebaa7"
+ },
+ "execution_count": 63,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[[{'generated_text': 'life seems to be slipping away from the hands yaar. I want to do something meaningful, not just survive. This is the only way to ground myself. NO. Thanks. #MeaningfulLiving #MemoryLane\\n'}]]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 63
+ }
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "Q6YuEi6ue55_",
+ "UTgfA4oXi3f1",
+ "8fFG1k8kYKJc",
+ "NHOCVttmi9n-",
+ "OoonXeFSjGqO",
+ "4QRB-MGhjRbV"
+ ],
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "1b8dc73f85514bf0b77f97e5ad581950": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1cafc9b47ff34c729a68ed7aef27d402": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1e275167e6344d56913e08ccfd7dc7f3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3934582e447d40e099a5f2ec526f0cc3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "VBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "VBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "VBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_cc94bf3a178f4ab2afb82bf21a2562fc",
+ "IPY_MODEL_9029fc63792d48f0933f7246aac21ad9",
+ "IPY_MODEL_b97936c82801415d9815fcc8f360c3b0",
+ "IPY_MODEL_90d36455009d4d9f867706ca01852908",
+ "IPY_MODEL_48a35b4471b74efca0611bc911ac9516"
+ ],
+ "layout": "IPY_MODEL_f3b206fa8b0e40a0b7d1d0a5c6dfb821"
+ }
+ },
+ "3b2b7983b12a418fb4ebbac7692a9b52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "48a35b4471b74efca0611bc911ac9516": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1cafc9b47ff34c729a68ed7aef27d402",
+ "placeholder": "",
+ "style": "IPY_MODEL_3b2b7983b12a418fb4ebbac7692a9b52",
+ "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks.