omark807 committed
Commit e1c12b5 · verified · 1 Parent(s): 8b46e7c

Upload web_a11y_model.py

Files changed (1)
  1. web_a11y_model.py +340 -0
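
The new script fine-tunes google/gemma-3-1b-it with LoRA (4-bit NF4 quantization when bitsandbytes is available, a plain GPU or CPU setup otherwise) on the omark807/web_a11y_dataset question/answer pairs. As a rough illustration — the field values below are made up — each dataset record is expected to look like:

# Illustrative record shape; the script only requires 'question' and 'answer' columns
example_record = {
    "question": "What does WCAG say about color contrast for body text?",
    "answer": "WCAG 2.1 AA requires a contrast ratio of at least 4.5:1 for normal-size text.",
}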
web_a11y_model.py ADDED
@@ -0,0 +1,340 @@
+ #!/usr/bin/env python3
+ # web_a11y_model.py
+
+ import os
+ import sys
+ import torch
+ import logging
+ from pathlib import Path
+ import traceback
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def check_environment():
+     """Check and report system environment"""
+     logger.info("=== Environment Check ===")
+     logger.info(f"Python version: {sys.version}")
+     logger.info(f"PyTorch version: {torch.__version__}")
+     logger.info(f"CUDA available: {torch.cuda.is_available()}")
+     if torch.cuda.is_available():
+         logger.info(f"CUDA version: {torch.version.cuda}")
+         logger.info(f"GPU count: {torch.cuda.device_count()}")
+         for i in range(torch.cuda.device_count()):
+             logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+             logger.info(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")
+
+ def main():
+     try:
+         check_environment()
+         logger.info("Importing required packages...")
+
+         try:
+             from datasets import load_dataset
+             from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+             from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+             from trl import SFTTrainer
+             logger.info("✓ All transformers packages imported successfully")
+         except ImportError as e:
+             logger.error(f"Failed to import transformers packages: {e}")
+             logger.error("Please ensure all packages are installed: pip install transformers datasets peft trl")
+             sys.exit(1)
+
+         # --- Configuration ---
+         MODEL_ID = "google/gemma-3-1b-it"
+         OUTPUT_DIR = "./results"
+         HUB_MODEL_ID = "omark807/gemma3-finetuned-web-accessibility"
+         NUM_TRAIN_EPOCHS = 3
+         PER_DEVICE_TRAIN_BATCH_SIZE = 2
+         GRADIENT_ACCUMULATION_STEPS = 4
+         LEARNING_RATE = 2e-4
+         SAVE_STEPS = 500
+         LOGGING_STEPS = 10
+         MAX_SEQ_LENGTH = 512
+
+         # Create output directory
+         Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+         logger.info(f"Output directory: {os.path.abspath(OUTPUT_DIR)}")
+
+         # --- Device Detection and Quantization Config ---
+         if torch.cuda.is_available():
+             logger.info("🚀 CUDA is available! Configuring for GPU training.")
+
+             try:
+                 import bitsandbytes  # noqa: F401  (imported only to confirm it is installed)
+                 from transformers import BitsAndBytesConfig  # BitsAndBytesConfig lives in transformers, not bitsandbytes
+                 logger.info("✓ BitsAndBytes imported successfully")
+
+                 bnb_config = BitsAndBytesConfig(
+                     load_in_4bit=True,
+                     bnb_4bit_quant_type="nf4",
+                     bnb_4bit_compute_dtype=torch.bfloat16,
+                     bnb_4bit_use_double_quant=False,
+                 )
+                 model_dtype = torch.bfloat16
+                 fp16_arg = False
+                 bf16_arg = True
+                 device_map = "auto"
+                 optimizer_type = "paged_adamw_8bit"
+                 logger.info("✓ 4-bit quantization configured")
+
+             except ImportError as e:
+                 logger.warning(f"BitsAndBytes import failed: {e}")
+                 logger.warning("Falling back to standard GPU configuration without quantization")
+                 bnb_config = None
+                 model_dtype = torch.float16  # Use float16 for GPU without quantization
+                 fp16_arg = True
+                 bf16_arg = False
+                 device_map = {"": 0}
+                 optimizer_type = "adamw_torch"
+
+         else:
+             logger.warning("⚠️ CUDA is NOT available. Using CPU configuration.")
+             logger.warning("Training will be significantly slower!")
+             bnb_config = None
+             model_dtype = torch.float32
+             fp16_arg = False
+             bf16_arg = False
+             device_map = "cpu"
+             optimizer_type = "adamw_torch"
+
+         # --- LoRA Configuration ---
+         lora_config = LoraConfig(
+             r=16,
+             lora_alpha=16,
+             target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+             bias="none",
+             lora_dropout=0.05,
+             task_type="CAUSAL_LM",
+         )
+         logger.info("✓ LoRA configuration set")
+
+         # --- Load Dataset ---
+         logger.info("Loading dataset...")
+         try:
+             ds = load_dataset("omark807/web_a11y_dataset")
+             logger.info(f"✓ Dataset loaded. Train samples: {len(ds['train'])}")
+
+             sample = ds['train'][0]
+             if 'question' not in sample or 'answer' not in sample:
+                 logger.error("Dataset must have 'question' and 'answer' columns")
+                 sys.exit(1)
+
+         except Exception as e:
+             logger.error(f"Failed to load dataset: {e}")
+             logger.error("Check your internet connection and dataset availability")
+             sys.exit(1)
+
+         # --- Load Tokenizer ---
+         logger.info(f"Loading tokenizer: {MODEL_ID}")
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+             # Handle tokenizer padding
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+             tokenizer.padding_side = "right"
+             tokenizer.model_max_length = MAX_SEQ_LENGTH
+             logger.info("✓ Tokenizer loaded and configured")
+
+         except Exception as e:
+             logger.error(f"Failed to load tokenizer: {e}")
+             sys.exit(1)
+
+         # --- Load Model ---
+         logger.info(f"Loading model: {MODEL_ID}")
+         try:
+             model_kwargs = {
+                 "torch_dtype": model_dtype,
+                 "device_map": device_map,
+                 "trust_remote_code": True,
+                 "use_cache": False,
+             }
+
+             # Add quantization config only if available
+             if bnb_config is not None:
+                 model_kwargs["quantization_config"] = bnb_config
+
+             model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
+
+             # Set pretraining_tp for Gemma
+             if hasattr(model.config, 'pretraining_tp'):
+                 model.config.pretraining_tp = 1
+
+             logger.info("✓ Model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             logger.error("This might be due to insufficient GPU memory or network issues")
+             sys.exit(1)
+
+         # --- Prepare Model for Training ---
+         logger.info("Preparing model for training...")
+         try:
+             # Prepare for k-bit training if using quantization
+             if bnb_config is not None:
+                 model = prepare_model_for_kbit_training(model)
+                 logger.info("✓ Model prepared for k-bit training")
+
+             # Apply LoRA
+             model = get_peft_model(model, lora_config)
+             logger.info("✓ LoRA applied to model")
+
+             # Train only the LoRA adapter weights; freeze everything else
+             for name, param in model.named_parameters():
+                 if "lora" in name:
+                     param.requires_grad = True
+                 elif param.requires_grad:
+                     param.requires_grad = False
+
+             if hasattr(model, 'lm_head'):
+                 for param in model.lm_head.parameters():
+                     param.requires_grad = True
+             elif hasattr(model, 'embed_out'):
+                 for param in model.embed_out.parameters():
+                     param.requires_grad = True
+             elif hasattr(model, 'base_model') and hasattr(model.base_model, 'lm_head'):
+                 for param in model.base_model.lm_head.parameters():
+                     param.requires_grad = True
+
+             # Keep the (possibly tied) embedding matrices frozen; for tied-embedding models such as
+             # Gemma this also re-freezes the LM head, leaving only the LoRA adapters trainable.
+             if hasattr(model, 'get_input_embeddings') and model.get_input_embeddings() is not None:
+                 model.get_input_embeddings().requires_grad_(False)
+             if hasattr(model, 'get_output_embeddings') and model.get_output_embeddings() is not None:
+                 model.get_output_embeddings().requires_grad_(False)
+
+             model.print_trainable_parameters()  # This will reflect the correct trainable params
+             logger.info("✓ Gradient requirements explicitly set for LoRA fine-tuning")
+
+         except Exception as e:
+             logger.error(f"Failed to prepare model: {e}")
+             logger.error(f"Full traceback: {traceback.format_exc()}")
+             sys.exit(1)
+
+         # --- Formatting Function (for pre-tokenization) ---
+         def tokenize_function(examples):
+             formatted_texts = []
+             for i in range(len(examples["question"])):
+                 question = examples["question"][i]
+                 answer = examples["answer"][i]
+                 formatted_text = f"<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n{answer}<end_of_turn>"
+                 formatted_texts.append(formatted_text)
+
+             # Tokenize the formatted texts directly
+             tokenized_inputs = tokenizer(
+                 formatted_texts,
+                 max_length=MAX_SEQ_LENGTH,
+                 truncation=True,
+                 padding="max_length",
+                 return_tensors="np",
+             )
+
+             # Add 'labels' for language modeling training
+             tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
+             return tokenized_inputs
+
+         # --- Pre-tokenize the dataset ---
+         logger.info("Pre-tokenizing dataset...")
+         try:
+             tokenized_ds = ds["train"].map(
+                 tokenize_function,
+                 batched=True,
+                 remove_columns=ds["train"].column_names,
+                 num_proc=os.cpu_count() or 1,
+             )
+             logger.info(f"✓ Dataset pre-tokenized. New train samples: {len(tokenized_ds)}")
+         except Exception as e:
+             logger.error(f"Failed to pre-tokenize dataset: {e}")
+             logger.error(f"Full traceback: {traceback.format_exc()}")
+             sys.exit(1)
+
+         # --- Training Arguments ---
+         training_args = TrainingArguments(
+             output_dir=OUTPUT_DIR,
+             num_train_epochs=NUM_TRAIN_EPOCHS,
+             per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
+             gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+             optim=optimizer_type,
+             learning_rate=LEARNING_RATE,
+             fp16=fp16_arg,
+             bf16=bf16_arg,
+             max_grad_norm=0.3,
+             warmup_ratio=0.03,
+             lr_scheduler_type="constant",
+             logging_steps=LOGGING_STEPS,
+             save_steps=SAVE_STEPS,
+             save_total_limit=3,
+             remove_unused_columns=False,
+             push_to_hub=False,
+             hub_model_id=HUB_MODEL_ID,
+             report_to="tensorboard",
+             dataloader_num_workers=0,
+             save_safetensors=True,
+             gradient_checkpointing=False,
+         )
+         logger.info("✓ Training arguments configured")
+
+         # --- Initialize Trainer ---
+         logger.info("Initializing SFTTrainer...")
+         try:
+             trainer = SFTTrainer(
+                 model=model,
+                 train_dataset=tokenized_ds,
+                 args=training_args,
+             )
+             logger.info("✓ SFTTrainer initialized successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize trainer: {e}")
+             logger.error(f"Full traceback: {traceback.format_exc()}")  # Added traceback for debugging
+             sys.exit(1)
+
+         # --- Start Training ---
+         logger.info("🚀 Starting fine-tuning...")
+         logger.info(f"Training for {NUM_TRAIN_EPOCHS} epochs")
+         logger.info(f"Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}, Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
+         logger.info(f"Effective batch size: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
+
+         try:
+             trainer.train()
+             logger.info("🎉 Fine-tuning completed successfully!")
+
+         except Exception as e:
+             logger.error(f"Training failed: {e}")
+             logger.error(f"Full traceback: {traceback.format_exc()}")
+             sys.exit(1)
+
+         # --- Save Model ---
+         logger.info("Saving model and tokenizer...")
+         try:
+             trainer.save_model(OUTPUT_DIR)
+             tokenizer.save_pretrained(OUTPUT_DIR)
+             logger.info(f"✓ Model saved to: {os.path.abspath(OUTPUT_DIR)}")
+
+             # Save training info
+             with open(os.path.join(OUTPUT_DIR, "training_info.txt"), "w") as f:
+                 f.write(f"Model: {MODEL_ID}\n")
+                 f.write(f"Epochs: {NUM_TRAIN_EPOCHS}\n")
+                 f.write(f"Learning rate: {LEARNING_RATE}\n")
+                 f.write(f"Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}\n")
+                 f.write(f"LoRA r: {lora_config.r}\n")
+                 f.write(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}\n")
+                 f.write(f"Quantization: {bnb_config is not None}\n")
+
+             logger.info("✅ All done! Model ready for use.")
+
+         except Exception as e:
+             logger.error(f"Failed to save model: {e}")
+             sys.exit(1)
+
+     except KeyboardInterrupt:
+         logger.info("Training interrupted by user")
+         sys.exit(1)
+     except Exception as e:
+         logger.error(f"Unexpected error: {e}")
+         logger.error(f"Full traceback: {traceback.format_exc()}")
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     main()
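
For reference, a minimal inference sketch (not part of this commit): it assumes the adapter and tokenizer were saved to ./results by this script, that peft and transformers are installed, and it reuses the same <start_of_turn> chat format applied in tokenize_function. The file name, example question, and generation settings are illustrative.

# inference_sketch.py — illustrative usage of the saved LoRA adapter
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

ADAPTER_DIR = "./results"  # OUTPUT_DIR used by web_a11y_model.py

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
model = AutoPeftModelForCausalLM.from_pretrained(
    ADAPTER_DIR,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()
device = next(model.parameters()).device

# Build a prompt in the same Gemma chat format used during training.
question = "How do I make a custom dropdown accessible to screen reader users?"
prompt = f"<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

# Decode only the newly generated tokens (skip the prompt).
answer = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)

If a later run enables push_to_hub, the same sketch should also work with HUB_MODEL_ID ("omark807/gemma3-finetuned-web-accessibility") in place of the local directory.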