jdaddyalbs committed
Commit 46576d8 · verified · 1 Parent(s): 7436348

uploaded notebook to train this model

Files changed (1)
  1. trainer-v1.ipynb +337 -0
trainer-v1.ipynb ADDED
@@ -0,0 +1,337 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bdb9c28f-d4a7-49e1-9d5e-74a878e2edc3",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
+ "from trl import SFTConfig, SFTTrainer\n",
+ "from datasets import load_dataset, Dataset\n",
+ "import torch\n",
+ "from mcp.types import Tool, ToolAnnotations\n",
+ "import os\n",
+ "import wandb\n",
+ "import json\n",
+ "from transformers import DataCollatorForSeq2Seq\n",
+ "from unsloth.chat_templates import train_on_responses_only\n",
+ "from urllib.parse import urlencode\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6ca3b85c-fd3d-4a2f-a239-2a1949d608fc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# token used to upload models\n",
+ "#HF_TOKEN = \"\"\n",
+ "# wandb optional for logging training data\n",
+ "#os.environ['WANDB_PROJECT'] = \"\"\n",
+ "#os.environ['WANDB_API_KEY'] = \"\"\n",
+ "\n",
+ "#wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1da19b51-f367-44c4-b4ad-f9dbe208ae56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "max_seq_length = 18000 # fits a 32GB GPU; increase for longer reasoning traces, reduce to avoid out-of-memory errors\n",
+ "lora_rank = 32 # Larger rank = smarter, but slower\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = \"unsloth/Qwen3-4B-bnb-4bit\",\n",
+ " #model_name = \"./qwen3-sft/checkpoint-765\", # uncomment to load a local checkpoint instead\n",
+ " max_seq_length = max_seq_length,\n",
+ " load_in_4bit = True, # False for LoRA 16bit\n",
+ " fast_inference = False,\n",
+ " max_lora_rank = lora_rank,\n",
+ " gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
+ ")\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = lora_rank, # choose any number > 0; suggested: 8, 16, 32, 64, 128\n",
+ " target_modules = [\n",
+ " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\",\n",
+ " ], # Remove QKVO if out of memory\n",
+ " lora_alpha = lora_rank*2,\n",
+ " use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
+ " random_state = 3407,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d60f8797-20b3-44a8-9068-156488863427",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# may be unnecessary; recent unsloth versions already ship this chat template with the tokenizer\n",
+ "from unsloth.chat_templates import get_chat_template\n",
+ "\n",
+ "tokenizer = get_chat_template(\n",
+ " tokenizer,\n",
+ " chat_template = \"qwen-3\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a0754d2a-baf7-406b-8450-273c668a0f38",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# see the dataset repo for details on dataset generation\n",
+ "dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", split=\"unprocessed\").remove_columns([\"text\"])\n",
+ "# alternatively, load the preprocessed train/test files directly (usable with other Qwen3 models that share this chat template)\n",
+ "# if you load the files below, skip ahead to the \"trainer = SFTTrainer(....\" cell\n",
+ "#eval_dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", data_files=\"data/test.parquet\")['train']\n",
+ "#train_dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", data_files=\"data/train.parquet\")['train']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2e055354-d080-4a71-ae20-227a88c79fa7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# tool definitions from Microsoft's playwright-mcp v0.0.31\n",
+ "tools = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", data_files=\"tools.txt\")\n",
+ "# tools.txt holds Tool(...) reprs; eval executes arbitrary code, so only run this on trusted data\n",
+ "tools = eval(\"\".join(tools['train']['text']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "263ecd0f-41b0-4cf2-bb47-61d6d0725960",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# convert Tool objects to OpenAI-style function-calling JSON\n",
+ "tools_json = [\n",
+ " {\n",
+ " \"type\":\"function\",\n",
+ " \"function\": {\n",
+ " \"name\": tool.name,\n",
+ " \"description\": tool.description,\n",
+ " \"parameters\": tool.inputSchema\n",
+ " }\n",
+ " } for tool in tools\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bb4290d2-8151-43c4-a9a0-28600c85a12b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# convert messages to correct format for model using chat template\n",
+ "def apply_template(messages):\n",
+ " return tokenizer.apply_chat_template(\n",
+ " messages,\n",
+ " tools=tools_json,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=False,\n",
+ " enable_thinking=True\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a6ec133c-8c7d-469d-8380-fb6cf54a7264",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "messages = []\n",
+ "for row in dataset['messages']:\n",
+ " msgs = [json.loads(msg) for msg in row]\n",
+ " messages.append(apply_template(msgs))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "71a5b8d5-0299-4c36-af9d-c028ecccf7cd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = dataset.add_column(\"text\", messages)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2f2bb399-bf1f-43a8-9f15-1efef2c35a80",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# to encourage tool use, keep only examples where tools were used (may or may not be helpful),\n",
+ "# and only those whose answer matched the reference (llm_match)\n",
+ "dataset = dataset.filter(lambda x: x[\"num_tools\"] > 0)\n",
+ "dataset = dataset.filter(lambda x: x[\"llm_match\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4a0ccd1a-d217-4ece-8773-73117ab355dc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# keep seed constant to get repeatable split\n",
+ "ds = dataset.train_test_split(test_size = 0.1, seed=42)\n",
+ "train_dataset = ds['train']\n",
+ "eval_dataset = ds['test']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "659a2864-fa2f-4abc-b4ae-6d267f7c1bd9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = SFTTrainer(\n",
+ " model = model,\n",
+ " tokenizer = tokenizer,\n",
+ " train_dataset = train_dataset,\n",
+ " eval_dataset = eval_dataset, # evaluated every eval_steps below\n",
+ " data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n",
+ " args = SFTConfig(\n",
+ " dataset_text_field = \"text\",\n",
+ " per_device_train_batch_size = 1, # bigger batches take up too much GPU memory\n",
+ " gradient_accumulation_steps = 4, # one gradient update per 4 samples; higher is more stable but slower\n",
+ " warmup_steps = 5,\n",
+ " num_train_epochs = 1, # one full training run\n",
+ " learning_rate = 2e-4, # Reduce to 2e-5 for long training runs\n",
+ " logging_steps = 1,\n",
+ " optim = \"adamw_8bit\",\n",
+ " weight_decay = 0.01,\n",
+ " lr_scheduler_type = \"linear\",\n",
+ " seed = 3407,\n",
+ " #report_to = \"wandb\", # uncomment to log to WandB\n",
+ " output_dir = 'qwen3-sft',\n",
+ " dataset_num_proc = 2,\n",
+ " eval_steps = 50,\n",
+ " fp16_full_eval = True,\n",
+ " per_device_eval_batch_size = 1,\n",
+ " eval_accumulation_steps = 1,\n",
+ " eval_strategy = \"steps\",\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e9d1051-7813-4570-9fcf-183d70e845bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# start the finetuning process\n",
+ "trainer_stats = trainer.train(resume_from_checkpoint=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e8d7a36-a5b6-452c-9fcf-5bb4325bd5d0",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# uncomment to export a q8_0 GGUF and upload it to the Hub (requires HF_TOKEN)\n",
+ "#model.push_to_hub_gguf(\"jdaddyalbs/qwen3_sft_playwright_gguf\", tokenizer, token=HF_TOKEN, quantization_method='q8_0')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ad0eb650-ea62-4a3a-b968-88d6f792dc28",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# uncomment to merge the LoRA into 16-bit weights and upload to the Hub\n",
+ "#model.push_to_hub_merged(\"jdaddyalbs/qwen3_sft_playwright\", tokenizer, token=HF_TOKEN, save_method=\"merged_16bit\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "40b285d6-49ff-4d8f-8220-7ad7e3ef4fa9",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# example: generate on a single eval sample and compare to the reference\n",
+ "idx = 51\n",
+ "print(eval_dataset[idx]['true_answer'])\n",
+ "print(eval_dataset[idx]['answer'])\n",
+ "\n",
+ "text = tokenizer.apply_chat_template(\n",
+ " eval_dataset[idx][\"evil_messages\"][:2],\n",
+ " tokenize = False,\n",
+ " tools = tools_json,\n",
+ " add_generation_prompt = True, # required for generation\n",
+ " enable_thinking = True,\n",
+ ")\n",
+ "\n",
+ "from transformers import TextStreamer\n",
+ "out = model.generate(\n",
+ " **tokenizer(text, return_tensors = \"pt\").to(\"cuda\"),\n",
+ " temperature = 0.0001, top_p = 0.95, top_k = 20, # near-greedy; top_p/top_k follow Qwen3 thinking-mode settings\n",
+ " max_new_tokens = 2048,\n",
+ " streamer = TextStreamer(tokenizer, skip_prompt = False),\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
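
Note: the first cell imports train_on_responses_only but the notebook never applies it, so the loss is computed over whole conversations, including user turns and tool outputs. If you want gradients only from the assistant's tokens, Unsloth's usual pattern is to wrap the trainer before calling trainer.train(). A minimal sketch, assuming the standard Qwen-3 role markers (verify them against your tokenizer's chat template):

    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<|im_start|>user\n",   # marks tokens to mask out of the labels
        response_part = "<|im_start|>assistant\n", # marks tokens the loss is computed on
    )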
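
The two commented push_to_hub_* cells upload straight to the Hub. To keep artifacts local instead, Unsloth provides matching save methods; a sketch with placeholder output directories:

    model.save_pretrained("qwen3-sft-lora")        # LoRA adapter weights only
    tokenizer.save_pretrained("qwen3-sft-lora")
    # local equivalents of the push_to_hub cells above:
    # model.save_pretrained_merged("qwen3-sft-merged", tokenizer, save_method = "merged_16bit")
    # model.save_pretrained_gguf("qwen3-sft-gguf", tokenizer, quantization_method = "q8_0")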