uploaded notebook to train this model
Browse files- trainer-v1.ipynb +337 -0
trainer-v1.ipynb
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "bdb9c28f-d4a7-49e1-9d5e-74a878e2edc3",
|
7 |
+
"metadata": {
|
8 |
+
"scrolled": true
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
|
13 |
+
"from trl import SFTConfig, SFTTrainer\n",
|
14 |
+
"from datasets import load_dataset, Dataset\n",
|
15 |
+
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
|
16 |
+
"import torch\n",
|
17 |
+
"from mcp.types import Tool, ToolAnnotations\n",
|
18 |
+
"import os \n",
|
19 |
+
"import wandb\n",
|
20 |
+
"import torch\n",
|
21 |
+
"import json\n",
|
22 |
+
"from transformers import DataCollatorForSeq2Seq\n",
|
23 |
+
"from unsloth.chat_templates import train_on_responses_only\n",
|
24 |
+
"from urllib.parse import urlencode\n"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": null,
|
30 |
+
"id": "6ca3b85c-fd3d-4a2f-a239-2a1949d608fc",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"# token used to upload models\n",
|
35 |
+
"#HF_TOKEN = \"\"\n",
|
36 |
+
"# wandb optional for logging training data\n",
|
37 |
+
"#os.environ['WANDB_PROJECT'] = \"\"\n",
|
38 |
+
"#os.environ['WANDB_API_KEY'] = \"\"\n",
|
39 |
+
"\n",
|
40 |
+
"#wandb.login()"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"id": "1da19b51-f367-44c4-b4ad-f9dbe208ae56",
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"max_seq_length = 18000 # Can increase for longer reasoning traces, works with 32GB GPU, may need to reduce to avoid out-of-memory errors\n",
|
51 |
+
"lora_rank = 32 # Larger rank = smarter, but slower\n",
|
52 |
+
"\n",
|
53 |
+
"\n",
|
54 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
55 |
+
" model_name = \"unsloth/Qwen3-4B-bnb-4bit\",\n",
|
56 |
+
" #model_name = \"./qwen3-sft/checkpoint-765\", # uncomment to load a local checkpoint like this example\n",
|
57 |
+
" max_seq_length = max_seq_length,\n",
|
58 |
+
" load_in_4bit = True, # False for LoRA 16bit\n",
|
59 |
+
" fast_inference = False,\n",
|
60 |
+
" max_lora_rank = lora_rank,\n",
|
61 |
+
" gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
|
62 |
+
")\n",
|
63 |
+
"\n",
|
64 |
+
"\n",
|
65 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
66 |
+
" model,\n",
|
67 |
+
" r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
|
68 |
+
" target_modules = [\n",
|
69 |
+
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
70 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
|
71 |
+
" ], # Remove QKVO if out of memory\n",
|
72 |
+
" lora_alpha = lora_rank*2,\n",
|
73 |
+
" use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
|
74 |
+
" random_state = 3407,\n",
|
75 |
+
")\n"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": null,
|
81 |
+
"id": "d60f8797-20b3-44a8-9068-156488863427",
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"# may not be necessary, unsloth version may already have this template in the tokenizer\n",
|
86 |
+
"from unsloth.chat_templates import get_chat_template\n",
|
87 |
+
"\n",
|
88 |
+
"tokenizer = get_chat_template(\n",
|
89 |
+
" tokenizer,\n",
|
90 |
+
" chat_template = \"qwen-3\",\n",
|
91 |
+
")"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": null,
|
97 |
+
"id": "a0754d2a-baf7-406b-8450-273c668a0f38",
|
98 |
+
"metadata": {
|
99 |
+
"scrolled": true
|
100 |
+
},
|
101 |
+
"outputs": [],
|
102 |
+
"source": [
|
103 |
+
"# see dataset repo for details on dataset generation\n",
|
104 |
+
"dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", split=\"unprocessed\").remove_columns([\"text\"])\n",
|
105 |
+
"# alternately you can load the preprocessed train test files directly (can be used with other qwen3 models with same chat template)\n",
|
106 |
+
"# if loading the files below, you can skip ahead to the \"trainer = SFTTrainer(....\" part\n",
|
107 |
+
"#eval_dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", data_files=\"data/test.parquet\")['train']\n",
|
108 |
+
"#train_dataset = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\", data_files=\"data/train.parquet\")['train']"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": null,
|
114 |
+
"id": "2e055354-d080-4a71-ae20-227a88c79fa7",
|
115 |
+
"metadata": {},
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"# tools from microsoft's playwright-mcp v0.0.31\n",
|
119 |
+
"tools = load_dataset(\"jdaddyalbs/playwright-mcp-toolcalling\",data_files=\"tools.txt\")\n",
|
120 |
+
"tools = eval(\"\".join([tools['train']['text'][i] for i in range(len(tools['train']['text']))]))"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": null,
|
126 |
+
"id": "263ecd0f-41b0-4cf2-bb47-61d6d0725960",
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"# convert to valid json\n",
|
131 |
+
"tools_json = [\n",
|
132 |
+
" {\n",
|
133 |
+
" \"type\":\"function\",\n",
|
134 |
+
" \"function\": {\n",
|
135 |
+
" \"name\": tool.name,\n",
|
136 |
+
" \"description\": tool.description,\n",
|
137 |
+
" \"parameters\": tool.inputSchema\n",
|
138 |
+
" }\n",
|
139 |
+
" } for tool in tools\n",
|
140 |
+
"]"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "code",
|
145 |
+
"execution_count": null,
|
146 |
+
"id": "bb4290d2-8151-43c4-a9a0-28600c85a12b",
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"# convert messages to correct format for model using chat template\n",
|
151 |
+
"def apply_template(messages):\n",
|
152 |
+
" return tokenizer.apply_chat_template(\n",
|
153 |
+
" messages,\n",
|
154 |
+
" tools=tools_json,\n",
|
155 |
+
" tokenize=False,\n",
|
156 |
+
" add_generation_prompt=False,\n",
|
157 |
+
" enable_thinking=True\n",
|
158 |
+
" ) "
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"id": "a6ec133c-8c7d-469d-8380-fb6cf54a7264",
|
165 |
+
"metadata": {
|
166 |
+
"scrolled": true
|
167 |
+
},
|
168 |
+
"outputs": [],
|
169 |
+
"source": [
|
170 |
+
"messages = []\n",
|
171 |
+
"for i in range(len(dataset['messages'])):\n",
|
172 |
+
" msgs = [json.loads(msg) for msg in dataset['messages'][i]]\n",
|
173 |
+
" messages.append(apply_template(msgs))"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"cell_type": "code",
|
178 |
+
"execution_count": null,
|
179 |
+
"id": "71a5b8d5-0299-4c36-af9d-c028ecccf7cd",
|
180 |
+
"metadata": {},
|
181 |
+
"outputs": [],
|
182 |
+
"source": [
|
183 |
+
"dataset = dataset.add_column(\"text\",messages)"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"id": "2f2bb399-bf1f-43a8-9f15-1efef2c35a80",
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"# I want to encourage the model to use tools, so I only give examples where tools are used\n",
|
194 |
+
"# May or may not be helpful\n",
|
195 |
+
"dataset = dataset.filter(lambda x: x[\"num_tools\"] > 0)\n",
|
196 |
+
"dataset = dataset.filter(lambda x: x[\"llm_match\"])"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "code",
|
201 |
+
"execution_count": null,
|
202 |
+
"id": "4a0ccd1a-d217-4ece-8773-73117ab355dc",
|
203 |
+
"metadata": {},
|
204 |
+
"outputs": [],
|
205 |
+
"source": [
|
206 |
+
"# keep seed constant to get repeatable split\n",
|
207 |
+
"ds = dataset.train_test_split(test_size = 0.1, seed=42)\n",
|
208 |
+
"train_dataset = ds['train']\n",
|
209 |
+
"eval_dataset = ds['test']"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": null,
|
215 |
+
"id": "659a2864-fa2f-4abc-b4ae-6d267f7c1bd9",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [],
|
218 |
+
"source": [
|
219 |
+
"trainer = SFTTrainer(\n",
|
220 |
+
" model = model,\n",
|
221 |
+
" tokenizer = tokenizer,\n",
|
222 |
+
" train_dataset = train_dataset,\n",
|
223 |
+
" eval_dataset = eval_dataset, # Can set up evaluation!\n",
|
224 |
+
" data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n",
|
225 |
+
" args = SFTConfig(\n",
|
226 |
+
" dataset_text_field = \"text\",\n",
|
227 |
+
" per_device_train_batch_size = 1, # bigger batches takes up too much GPU memory\n",
|
228 |
+
" gradient_accumulation_steps = 4, # 1 gradient update every 4 samples, higher should make training more stable but take longer\n",
|
229 |
+
" warmup_steps = 5,\n",
|
230 |
+
" num_train_epochs = 1, # Set this for 1 full training run.\n",
|
231 |
+
" learning_rate = 2e-4, # Reduce to 2e-5 for long training runs\n",
|
232 |
+
" logging_steps = 1,\n",
|
233 |
+
" optim = \"adamw_8bit\",\n",
|
234 |
+
" weight_decay = 0.01,\n",
|
235 |
+
" lr_scheduler_type = \"linear\",\n",
|
236 |
+
" seed = 3407,\n",
|
237 |
+
" #report_to = \"wandb\", # Use this for WandB, comment out if not using\n",
|
238 |
+
" output_dir='qwen3-sft',\n",
|
239 |
+
" dataset_num_proc=2,\n",
|
240 |
+
" eval_steps=50,\n",
|
241 |
+
" fp16_full_eval = True,\n",
|
242 |
+
" per_device_eval_batch_size = 1,\n",
|
243 |
+
" eval_accumulation_steps = 1,\n",
|
244 |
+
" eval_strategy = \"steps\",\n",
|
245 |
+
" ),\n",
|
246 |
+
")"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
{
|
250 |
+
"cell_type": "code",
|
251 |
+
"execution_count": null,
|
252 |
+
"id": "8e9d1051-7813-4570-9fcf-183d70e845bb",
|
253 |
+
"metadata": {},
|
254 |
+
"outputs": [],
|
255 |
+
"source": [
|
256 |
+
"# start the finetuning process\n",
|
257 |
+
"trainer_stats = trainer.train(resume_from_checkpoint=False)"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": null,
|
263 |
+
"id": "7e8d7a36-a5b6-452c-9fcf-5bb4325bd5d0",
|
264 |
+
"metadata": {
|
265 |
+
"scrolled": true
|
266 |
+
},
|
267 |
+
"outputs": [],
|
268 |
+
"source": [
|
269 |
+
"#model.push_to_hub_gguf(\"jdaddyalbs/qwen3_sft_playwright_gguf\", tokenizer,token=HF_TOKEN, quantization_method='q8_0')"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": null,
|
275 |
+
"id": "ad0eb650-ea62-4a3a-b968-88d6f792dc28",
|
276 |
+
"metadata": {
|
277 |
+
"scrolled": true
|
278 |
+
},
|
279 |
+
"outputs": [],
|
280 |
+
"source": [
|
281 |
+
"#model.push_to_hub_merged(\"jdaddyalbs/qwen3_sft_playwright\",tokenizer,token=HF_TOKEN,save_method=\"merged_16bit\")"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "code",
|
286 |
+
"execution_count": null,
|
287 |
+
"id": "40b285d6-49ff-4d8f-8220-7ad7e3ef4fa9",
|
288 |
+
"metadata": {
|
289 |
+
"scrolled": true
|
290 |
+
},
|
291 |
+
"outputs": [],
|
292 |
+
"source": [
|
293 |
+
"# example evaluate a single sample\n",
|
294 |
+
"idx = 51\n",
|
295 |
+
"print(eval_dataset[idx]['true_answer'])\n",
|
296 |
+
"print(eval_dataset[idx]['answer'])\n",
|
297 |
+
"\n",
|
298 |
+
"text = tokenizer.apply_chat_template(\n",
|
299 |
+
" eval_dataset[idx][\"evil_messages\"][:2],\n",
|
300 |
+
" tokenize = False,\n",
|
301 |
+
" tools=tools_json,\n",
|
302 |
+
" add_generation_prompt = True, # Must add for generation\n",
|
303 |
+
" enable_thinking = True,\n",
|
304 |
+
")\n",
|
305 |
+
"\n",
|
306 |
+
"from transformers import TextStreamer\n",
|
307 |
+
"out = model.generate(\n",
|
308 |
+
" **tokenizer(text, return_tensors = \"pt\").to(\"cuda\"),\n",
|
309 |
+
" temperature = 0.0001, top_p = 0.95, top_k = 20, # For thinking\n",
|
310 |
+
" max_new_tokens = 2048,\n",
|
311 |
+
" streamer = TextStreamer(tokenizer, skip_prompt = False),\n",
|
312 |
+
")"
|
313 |
+
]
|
314 |
+
}
|
315 |
+
],
|
316 |
+
"metadata": {
|
317 |
+
"kernelspec": {
|
318 |
+
"display_name": "Python 3 (ipykernel)",
|
319 |
+
"language": "python",
|
320 |
+
"name": "python3"
|
321 |
+
},
|
322 |
+
"language_info": {
|
323 |
+
"codemirror_mode": {
|
324 |
+
"name": "ipython",
|
325 |
+
"version": 3
|
326 |
+
},
|
327 |
+
"file_extension": ".py",
|
328 |
+
"mimetype": "text/x-python",
|
329 |
+
"name": "python",
|
330 |
+
"nbconvert_exporter": "python",
|
331 |
+
"pygments_lexer": "ipython3",
|
332 |
+
"version": "3.13.5"
|
333 |
+
}
|
334 |
+
},
|
335 |
+
"nbformat": 4,
|
336 |
+
"nbformat_minor": 5
|
337 |
+
}
|