kingabzpro committed (verified)
Commit 5f7a806 · 1 Parent(s): 43aa2d2

Fine-tuning script

Files changed (1):
  Fine_tuning_llama4 (Original).ipynb  +702 -0
Fine_tuning_llama4 (Original).ipynb ADDED
@@ -0,0 +1,702 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "7d7feba1-086e-4e5f-9fd8-570ecacc5205",
+ "metadata": {},
+ "source": [
+ "## Setting up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "6b17d385-ca37-4e8a-937d-bebbd221a386",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "%pip install -U datasets \n",
+ "%pip install -U accelerate \n",
+ "%pip install -U peft \n",
+ "%pip install -U trl \n",
+ "%pip install -U bitsandbytes\n",
+ "%pip install huggingface_hub[hf_xet]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "18b1954e-46e9-4877-940d-ff424cef921e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!pip install transformers==4.51.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cca3c6fa-0f55-4944-8cf1-be8f130cc1e3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "993b193c89c14e919fe0841db78c7011",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from huggingface_hub import login\n",
+ "import os\n",
+ "\n",
+ "hf_token = os.environ.get(\"HF_TOKEN\")\n",
+ "login(hf_token)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "086b243c-1616-4664-887c-ad803f0a09fd",
+ "metadata": {},
+ "source": [
+ "## Loading the model and tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a1fa3f81-a593-49de-bc35-7989e74f4d60",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7af36e4cb39a4ac699810bad1a4db9d0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/50 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig\n",
+ "\n",
+ "\n",
+ "model_id = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
+ "\n",
+ "bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_use_double_quant=False,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "model = Llama4ForConditionalGeneration.from_pretrained(\n",
+ " model_id,\n",
+ " device_map=\"auto\", \n",
+ " torch_dtype=torch.bfloat16,\n",
+ " quantization_config=bnb_config,\n",
+ " trust_remote_code=True,\n",
+ ")\n",
+ "\n",
+ "model.config.use_cache = False\n",
+ "model.config.pretraining_tp = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "e7eb2d9c-2862-497b-801d-735fe4276233",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load tokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "08ee3942-5750-4196-9cd3-a9da9abdd3b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# model.push_to_hub(\"Llama-4-Scout-17B-16E-Instruct-bnb-4bit\")\n",
+ "# tokenizer.push_to_hub(\"Llama-4-Scout-17B-16E-Instruct-bnb-4bit\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "dde054e4-a549-4e3a-a057-6eb739035646",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wed Apr 9 18:45:00 2025 \n",
+ "+-----------------------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 |\n",
+ "|-----------------------------------------+------------------------+----------------------+\n",
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|=========================================+========================+======================|\n",
+ "| 0 NVIDIA H200 On | 00000000:AA:00.0 Off | 0 |\n",
+ "| N/A 33C P0 115W / 700W | 54334MiB / 143771MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-----------------------------------------+------------------------+----------------------+\n",
+ "| 1 NVIDIA H200 On | 00000000:BA:00.0 Off | 0 |\n",
+ "| N/A 30C P0 115W / 700W | 63710MiB / 143771MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-----------------------------------------+------------------------+----------------------+\n",
+ "| 2 NVIDIA H200 On | 00000000:CA:00.0 Off | 0 |\n",
+ "| N/A 33C P0 118W / 700W | 77494MiB / 143771MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-----------------------------------------+------------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=========================================================================================|\n",
+ "+-----------------------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "01e4c418-4538-4140-bbdb-63563d952c1e",
+ "metadata": {},
+ "source": [
+ "## Loading and processing the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "557f53e8-b334-4cf1-99da-5ec2f2b39ef6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_prompt_style = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. \n",
+ "Write a response that appropriately completes the request. \n",
+ "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n",
+ "\n",
+ "### Instruction:\n",
+ "You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
+ "Please answer the following medical question. \n",
+ "\n",
+ "### Question:\n",
+ "{}\n",
+ "\n",
+ "### Response:\n",
+ "<think>\n",
+ "{}\n",
+ "</think>\n",
+ "{}\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "267ead23-7502-4b02-9fb7-f213081dbea6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN\n",
+ "\n",
+ "def formatting_prompts_func(examples):\n",
+ " inputs = examples[\"Question\"]\n",
+ " complex_cots = examples[\"Complex_CoT\"]\n",
+ " outputs = examples[\"Response\"]\n",
+ " texts = []\n",
+ " for question, cot, response in zip(inputs, complex_cots, outputs):\n",
+ " # Append the EOS token to the response if it's not already there\n",
+ " if not response.endswith(tokenizer.eos_token):\n",
+ " response += tokenizer.eos_token\n",
+ " text = train_prompt_style.format(question, cot, response)\n",
+ " texts.append(text)\n",
+ " return {\"text\": texts}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "a9de39ca-e687-4f56-8f76-2f7f69f81ac1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"Below is an instruction that describes a task, paired with an input that provides further context. \\nWrite a response that appropriately completes the request. \\nBefore answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\\n\\n### Instruction:\\nYou are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \\nPlease answer the following medical question. \\n\\n### Question:\\nA 61-year-old woman with a long history of involuntary urine loss during activities like coughing or sneezing but no leakage at night undergoes a gynecological exam and Q-tip test. Based on these findings, what would cystometry most likely reveal about her residual volume and detrusor contractions?\\n\\n### Response:\\n<think>\\nOkay, let's think about this step by step. There's a 61-year-old woman here who's been dealing with involuntary urine leakages whenever she's doing something that ups her abdominal pressure like coughing or sneezing. This sounds a lot like stress urinary incontinence to me. Now, it's interesting that she doesn't have any issues at night; she isn't experiencing leakage while sleeping. This likely means her bladder's ability to hold urine is fine when she isn't under physical stress. Hmm, that's a clue that we're dealing with something related to pressure rather than a bladder muscle problem. \\n\\nThe fact that she underwent a Q-tip test is intriguing too. This test is usually done to assess urethral mobility. In stress incontinence, a Q-tip might move significantly, showing urethral hypermobility. This kind of movement often means there's a weakness in the support structures that should help keep the urethra closed during increases in abdominal pressure. So, that's aligning well with stress incontinence.\\n\\nNow, let's think about what would happen during cystometry. Since stress incontinence isn't usually about sudden bladder contractions, I wouldn't expect to see involuntary detrusor contractions during this test. Her bladder isn't spasming or anything; it's more about the support structure failing under stress. Plus, she likely empties her bladder completely because stress incontinence doesn't typically involve incomplete emptying. So, her residual volume should be pretty normal. \\n\\nAll in all, it seems like if they do a cystometry on her, it will likely show a normal residual volume and no involuntary contractions. Yup, I think that makes sense given her symptoms and the typical presentations of stress urinary incontinence.\\n</think>\\nCystometry in this case of stress urinary incontinence would most likely reveal a normal post-void residual volume, as stress incontinence typically does not involve issues with bladder emptying. Additionally, since stress urinary incontinence is primarily related to physical exertion and not an overactive bladder, you would not expect to see any involuntary detrusor contractions during the test.<|eot|>\""
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "\n",
+ "dataset = load_dataset(\"FreedomIntelligence/medical-o1-reasoning-SFT\",\"en\", split = \"train[0:500]\",trust_remote_code=True)\n",
+ "dataset = dataset.map(formatting_prompts_func, batched = True,)\n",
+ "dataset[\"text\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "addc3c2a-107e-412c-ac2d-471a2e1782fb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import DataCollatorForLanguageModeling\n",
+ "\n",
+ "data_collator = DataCollatorForLanguageModeling(\n",
+ " tokenizer=tokenizer,\n",
+ " mlm=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ad649bed-3a2c-4709-8120-4f76ae4a1e84",
+ "metadata": {},
+ "source": [
+ "## Model inference before fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c72d112a-6da1-4c5c-a310-7271526bd838",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt_style = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. \n",
+ "Write a response that appropriately completes the request. \n",
+ "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n",
+ "\n",
+ "### Instruction:\n",
+ "You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n",
+ "Please answer the following medical question. \n",
+ "\n",
+ "### Question:\n",
+ "{}\n",
+ "\n",
+ "### Response:\n",
+ "<think>{}\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "5905dc8b-e22d-4a1a-8a9f-6c3490bff33b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "<think>assistant\n",
+ "\n",
+ "To approach this question, let's break down the key elements provided and analyze them step by step:\n",
+ "\n",
+ "1. **Symptoms**: The patient experiences involuntary urine loss during activities like coughing or sneezing but has no leakage at night. This pattern of urinary incontinence is suggestive of stress urinary incontinence (SUI), which is characterized by the involuntary leakage of urine on effort or exertion, or on sneezing or coughing.\n",
+ "\n",
+ "2. **Diagnostic Tests Mentioned**:\n",
+ " - **Gynecological Exam**: This is likely performed to assess the pelvic anatomy, including the position and support of the urethra and bladder neck, and to check for any pelvic organ prolapse.\n",
+ " - **Q-tip Test**: This test is used to assess urethral mobility. A Q-tip (cotton swab) is inserted into the urethra, and its angle of movement is measured. Increased mobility (an angle change of more than 30 degrees) is often associated with stress urinary incontinence.\n",
+ "\n",
+ "3. **Cystometry (Cystometrogram)**: This test measures the pressure within the bladder during filling and helps assess bladder function, including the residual volume and detrusor muscle contractions. The detrusor muscle is the smooth muscle in the wall of the bladder that contracts to allow urine to be expelled.\n",
+ "\n",
+ "Given that the patient likely has stress urinary incontinence (SUI) based on her symptoms:\n",
+ "- **Residual Volume**: In patients with SUI, the bladder usually functions normally, and thus, the residual volume (the amount of urine left in the bladder after urination) is typically not significantly affected. Therefore, one would expect the residual volume to be normal or near-normal.\n",
+ " \n",
+ "- **Detrusor Contractions**: In SUI, the problem primarily lies with the urethral sphincter mechanism and support rather than with the detrusor muscle itself. Hence, detrusor contractions are usually normal. The patient does not have symptoms suggestive of an overactive bladder (like urgency, urge incontinence, or nocturia), which would be more indicative of detrusor overactivity.\n",
+ "\n",
+ "Based on this analysis, cystometry in this patient would most likely reveal:\n",
+ "- A **normal residual volume**, as her symptoms do not suggest a problem with bladder emptying.\n",
+ "- **Normal detrusor contractions**, as her condition (stress urinary incontinence) primarily involves issues with urethral support and continence mechanisms rather than detrusor function.\n",
+ "\n",
+ "Therefore, cystometry would likely show that she has a normal residual volume and normal detrusor contractions. \n",
+ "\n",
+ "</think>\n",
+ "\n",
+ "The final answer is: $\\boxed{Normal residual volume and normal detrusor contractions}$\n"
+ ]
+ }
+ ],
+ "source": [
+ "question = dataset[0]['Question']\n",
+ "inputs = tokenizer(\n",
+ " [prompt_style.format(question, \"\") + tokenizer.eos_token],\n",
+ " return_tensors=\"pt\"\n",
+ ").to(\"cuda\")\n",
+ "\n",
+ "outputs = model.generate(\n",
+ " input_ids=inputs.input_ids,\n",
+ " attention_mask=inputs.attention_mask,\n",
+ " max_new_tokens=1200,\n",
+ " eos_token_id=tokenizer.eos_token_id,\n",
+ " use_cache=True,\n",
+ ")\n",
+ "response = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
+ "print(response[0].split(\"### Response:\")[1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "985599a2-9a6a-4629-9a95-0893e044eb8a",
+ "metadata": {},
+ "source": [
+ "## Setting up the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "d332cefc-9b12-4121-a355-5590e7fd6a4f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import LoraConfig, get_peft_model\n",
+ "\n",
+ "# LoRA config\n",
+ "peft_config = LoraConfig(\n",
+ " lora_alpha=16, # Scaling factor for LoRA\n",
+ " lora_dropout=0.05, # Add slight dropout for regularization\n",
+ " r=64, # Rank of the LoRA update matrices\n",
+ " bias=\"none\", # No bias reparameterization\n",
+ " task_type=\"CAUSAL_LM\", # Task type: Causal Language Modeling\n",
+ " target_modules=[\n",
+ " \"q_proj\",\n",
+ " \"k_proj\",\n",
+ " \"v_proj\",\n",
+ " \"o_proj\",\n",
+ " \"gate_proj\",\n",
+ " \"up_proj\",\n",
+ " \"down_proj\",\n",
+ " ], # Target modules for LoRA\n",
+ ")\n",
+ "\n",
+ "model = get_peft_model(model, peft_config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "7498899d-0986-4330-b4bf-16b29f7f1658",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "\n",
+ "\n",
+ "# Training Arguments\n",
+ "training_arguments = TrainingArguments(\n",
+ " output_dir=\"output\",\n",
+ " per_device_train_batch_size=1,\n",
+ " per_device_eval_batch_size=1,\n",
+ " gradient_accumulation_steps=2,\n",
+ " optim=\"paged_adamw_32bit\",\n",
+ " num_train_epochs=1,\n",
+ " logging_steps=0.2,\n",
+ " warmup_steps=10,\n",
+ " logging_strategy=\"steps\",\n",
+ " learning_rate=2e-4,\n",
+ " fp16=False,\n",
+ " bf16=False,\n",
+ " group_by_length=True,\n",
+ " report_to=\"none\"\n",
+ ")\n",
+ "\n",
+ "# Initialize the Trainer\n",
+ "trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " args=training_arguments,\n",
+ " train_dataset=dataset,\n",
+ " peft_config=peft_config,\n",
+ " data_collator=data_collator,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aaf6a32b-729c-4785-9d1b-9b8c07e63a58",
+ "metadata": {},
+ "source": [
+ "## Model training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "eda9476a-2f35-4d7f-b08b-355f9a93659d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='250' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [250/250 06:48, Epoch 1/1]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Step</th>\n",
+ " <th>Training Loss</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>50</td>\n",
+ " <td>2.918900</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>100</td>\n",
+ " <td>2.391100</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>150</td>\n",
+ " <td>2.437700</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>200</td>\n",
+ " <td>2.412300</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>250</td>\n",
+ " <td>2.480700</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=250, training_loss=2.528125, metrics={'train_runtime': 410.7281, 'train_samples_per_second': 1.217, 'train_steps_per_second': 0.609, 'total_flos': 2.2490325698347315e+17, 'train_loss': 2.528125})"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7a624af0-0cfe-4932-b149-93a286d519c1",
+ "metadata": {},
+ "source": [
+ "## Model inference after fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b65e6fca-8f46-4bfc-a2d0-cf411cb38120",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "<think>assistant\n",
+ "\n",
+ "<think>\n",
+ "Alright, let's think about what's going on with this 61-year-old woman. She's experiencing involuntary urine loss when she coughs or sneezes, which is classic for stress urinary incontinence. But there's no leakage at night, which is interesting. That detail might help us rule out other types of incontinence, like overactive bladder.\n",
+ "\n",
+ "Now, let's think about the tests she's had. The gynecological exam and the Q-tip test are usually done to check for any anatomical issues that could be causing her symptoms. If the Q-tip test is positive, it means there's some mobility in the urethra, which can be a sign of stress incontinence.\n",
+ "\n",
+ "Okay, so we've got stress urinary incontinence in mind. What does cystometry tell us? It's a test that looks at how well the bladder functions by measuring pressure and volume during filling and voiding. \n",
+ "\n",
+ "For someone with stress incontinence, the cystometry would probably show that her bladder is doing its job properly. The residual volume, which is the amount of urine left in the bladder after she pees, should be normal. This means her bladder is emptying well enough.\n",
+ "\n",
+ "Also, with stress incontinence, you wouldn't expect to see abnormal detrusor contractions. These are like muscle spasms in the bladder that can cause urgency and urge incontinence. But in stress incontinence, the problem is more about the urethral support and pressure control during activities, not the bladder muscle itself.\n",
+ "\n",
+ "So, putting it all together, the cystometry should show a normal residual volume and no abnormal detrusor contractions. This fits perfectly with what we know about stress urinary incontinence. Yep, that's what I'd expect to see in this case.\n",
+ "</think>\n",
+ "In this scenario, the cystometry would most likely reveal a normal residual volume and no abnormal detrusor contractions. This is because stress urinary incontinence, as indicated by the patient's symptoms and the positive Q-tip test, is primarily related to issues with urethral support and pressure control during activities, not with the bladder muscle's function or emptying ability.\n"
+ ]
+ }
+ ],
+ "source": [
+ "question = dataset[0]['Question']\n",
+ "inputs = tokenizer(\n",
+ " [prompt_style.format(question, \"\") + tokenizer.eos_token],\n",
+ " return_tensors=\"pt\"\n",
+ ").to(\"cuda\")\n",
+ "\n",
+ "outputs = model.generate(\n",
+ " input_ids=inputs.input_ids,\n",
+ " attention_mask=inputs.attention_mask,\n",
+ " max_new_tokens=1200,\n",
+ " eos_token_id=tokenizer.eos_token_id,\n",
+ " use_cache=True,\n",
+ ")\n",
+ "response = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
+ "print(response[0].split(\"### Response:\")[1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "b05e005d-f4b1-4af3-896b-70871cfbc600",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "<think>assistant\n",
+ "\n",
+ "<think>\n",
+ "Alright, let's break this down. So, we have a 42-year-old guy who's just recovered from pneumonia, and his T3 levels are low. Now, his TSH is slightly elevated at 4.7 µU/mL, which isn't too high but is definitely not normal. T4 is at 6 µg/dL, which is kind of in the middle range, and T3 is 68 ng/dL, which is low. Hmm, these numbers make me think of a condition called sick euthyroid syndrome. It's like when the body gets sick, the thyroid function tests can get all wonky, even if the thyroid itself is working fine.\n",
+ "\n",
+ "In this syndrome, T3 usually drops first, and T4 might stay normal or even rise a bit. This makes sense here because his T3 is low, but T4 is still around 6 µg/dL. Now, the TSH is slightly elevated, which is a bit puzzling, but in sick euthyroid syndrome, TSH can be normal or even slightly elevated. So, this fits.\n",
+ "\n",
+ "But wait, what about the other hormones? In sick euthyroid syndrome, we often see an increase in reverse T3 (rT3). This is because the body might convert T4 to rT3 instead of T3 when it's stressed. So, if I were to guess, I'd say the reverse T3 level is probably elevated in this case.\n",
+ "\n",
+ "This all makes sense because the clinical picture and lab results are lining up nicely with sick euthyroid syndrome. So, if I had to pick an additional hormone that's likely elevated, it would be the reverse T3. Yeah, that seems to fit the scenario perfectly.\n",
+ "</think>\n",
+ "In this scenario, considering the clinical context of a 42-year-old man recovering from pneumonia with decreased T3 levels and slightly elevated TSH, along with normal T4 levels, the likely condition is sick euthyroid syndrome. In this condition, the body often converts T4 to reverse T3 (rT3) instead of T3 during stress or illness. Therefore, the additional hormone level that is likely to be elevated in this patient is reverse T3 (rT3).\n"
+ ]
+ }
+ ],
+ "source": [
+ "question = dataset[10]['Question']\n",
+ "inputs = tokenizer(\n",
+ " [prompt_style.format(question, \"\") + tokenizer.eos_token],\n",
+ " return_tensors=\"pt\"\n",
+ ").to(\"cuda\")\n",
+ "\n",
+ "outputs = model.generate(\n",
+ " input_ids=inputs.input_ids,\n",
+ " attention_mask=inputs.attention_mask,\n",
+ " max_new_tokens=1200,\n",
+ " eos_token_id=tokenizer.eos_token_id,\n",
+ " use_cache=True,\n",
+ ")\n",
+ "response = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
+ "print(response[0].split(\"### Response:\")[1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7791b473-13fd-485f-a40a-21ac12bbe497",
+ "metadata": {},
+ "source": [
+ "## Saving the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b5dd4891-47f7-485b-ae31-b26e8e9c9565",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bdd0bfa862a9471584203114e5953b68",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_model.safetensors: 0%| | 0.00/992M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.push_to_hub(\"Llama-4-Scout-17B-16E-Instruct-Medical-ChatBot\")\n",
+ "tokenizer.push_to_hub(\"Llama-4-Scout-17B-16E-Instruct-Medical-ChatBot\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4890599b-7325-4ca5-b0b2-402e8c38ea50",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
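
For reference, a minimal sketch of how the LoRA adapter pushed in the final cell could be reloaded for inference with peft and transformers. The "kingabzpro/" Hub namespace is an assumption based on the commit author and is not stated in the notebook itself.

import torch
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel

base_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
adapter_id = "kingabzpro/Llama-4-Scout-17B-16E-Instruct-Medical-ChatBot"  # assumed namespace

# Reload the base model in 4-bit, mirroring the quantization settings used in the notebook
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = Llama4ForConditionalGeneration.from_pretrained(
    base_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(base_id)

# Attach the fine-tuned LoRA weights (adapter_model.safetensors) on top of the base model
model = PeftModel.from_pretrained(base, adapter_id)
model.eval()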