AvocadoMuffin committed (verified)
Commit 22050d0 · Parent(s): 5b5e488

Update train.py

Files changed (1):
  train.py (+168 -82)

train.py CHANGED
@@ -1,12 +1,14 @@
 #!/usr/bin/env python
- # train_cuad_lora.py
+ # train_cuad_lora_improved.py
 """
 CUAD fine-tune with LoRA on an L4 / T4 GPU.
+ Improved version with better error handling and recovery mechanisms.
 Expected wall-clock on Nvidia L4: ~25-30 min.
 """

- import os, json, random, gc
+ import os, json, random, gc, time
 from collections import defaultdict
+ from pathlib import Path

 import torch, numpy as np
 from datasets import load_dataset, Dataset, disable_caching
@@ -14,7 +16,6 @@ from transformers import (
     AutoTokenizer, AutoModelForQuestionAnswering,
     TrainingArguments, default_data_collator, Trainer
 )
- # FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
 from peft import LoraConfig, get_peft_model, TaskType
 import evaluate
 from huggingface_hub import login
@@ -26,11 +27,25 @@ disable_caching()  # avoids giant disk cache on Colab
 MAX_LEN = 384  # window
 DOC_STRIDE = 128
 SEED = 42
+ CHECKPOINT_DIR = "./cuad_lora_checkpoints"

 def set_seed(seed):
     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

+ def save_checkpoint(data, checkpoint_path):
+     """Save preprocessing checkpoint to disk"""
+     os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
+     torch.save(data, checkpoint_path)
+     print(f"💾 Checkpoint saved: {checkpoint_path}")
+
+ def load_checkpoint(checkpoint_path):
+     """Load preprocessing checkpoint from disk"""
+     if os.path.exists(checkpoint_path):
+         print(f"📂 Loading checkpoint: {checkpoint_path}")
+         return torch.load(checkpoint_path)
+     return None
+
 def balance_has_answer(dataset, ratio=2.0):
     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
     has, no = [], []
@@ -81,7 +96,7 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):
     return predictions

 def compute_metrics(eval_pred):
-     """FIXED: Use regular eval_pred structure and correct variable names"""
+     """Use regular eval_pred structure and correct variable names"""
     predictions = postprocess_qa(val_raw, val_feats, eval_pred.predictions, tok)
     references = [
         {"id": ex["id"], "answers": ex["answers"]} for ex in val_raw
@@ -90,55 +105,9 @@ def compute_metrics(eval_pred):

 # ───────────────────────────────────────────────────────────────── main ──

- def main():
-     global val_raw, val_feats, tok  # FIXED: Use correct variable names
-
-     set_seed(SEED)
-
-     # model name to store on Hub
-     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
-
-     if (tokn := os.getenv("roberta_token")):
-         try:
-             login(tokn)
-             print("🔑 HuggingFace Hub login OK")
-         except Exception as e:
-             print(f"⚠️ Hub login failed: {e}")
-             print("📝 Training will continue but won't push to Hub")
-             tokn = None  # Disable pushing
-
-     print("📚 Loading CUAD…")
-     try:
-         cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
-         print(f"✅ Loaded {len(cuad)} examples")
-     except Exception as e:
-         print(f"❌ Dataset loading failed: {e}")
-         print("🔄 Retrying with cache disabled...")
-         cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")
-
-     cuad = cuad.shuffle(seed=SEED)
-     cuad = balance_has_answer(cuad, ratio=2.0)  # ≈18 k rows
-     print(f"📊 Balanced dataset: {len(cuad)} examples")
-
-     # train / val 90-10
-     ds = cuad.train_test_split(test_size=0.1, seed=SEED)
-     train_raw, val_raw = ds["train"], ds["test"]
-
-     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
-     base_ckpt = "deepset/roberta-base-squad2"
-     tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
-     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
-
-     # LoRA
-     lora = LoraConfig(
-         task_type=TaskType.QUESTION_ANS,
-         r=16, lora_alpha=32, lora_dropout=0.05,
-         target_modules=["query", "value"],
-     )
-     model = get_peft_model(model, lora)
-     model.print_trainable_parameters()
-
-     # ── preprocess (OPTIMIZED) ─────────────────────────────────────────
+ def preprocess_with_retry(dataset, dataset_name, max_retries=3):
+     """Preprocess dataset with retry logic and checkpointing"""

     def preprocess(examples):
         # Tokenize all at once
         tokenized = tok(
@@ -204,29 +173,115 @@ def main():
         tokenized["example_id"] = example_ids
         return tokenized

-     print("🔄 Preprocessing training data...")
-     train_feats = train_raw.map(
-         preprocess,
-         batched=True,
-         remove_columns=train_raw.column_names,
-         num_proc=4,  # Use multiple processes for speed
-         desc="tokenise-train",
-         load_from_cache_file=False,
-         batch_size=100  # Process in smaller batches
+     checkpoint_path = f"{CHECKPOINT_DIR}/{dataset_name}_features.pt"
+
+     # Try to load from checkpoint first
+     features = load_checkpoint(checkpoint_path)
+     if features is not None:
+         print(f"✅ Loaded {dataset_name} features from checkpoint")
+         return Dataset.from_dict(features)
+
+     # Process with retries
+     for attempt in range(max_retries):
+         try:
+             print(f"🔄 Preprocessing {dataset_name} data (attempt {attempt + 1}/{max_retries})...")
+
+             # Use smaller batch sizes and reduce num_proc for stability
+             features = dataset.map(
+                 preprocess,
+                 batched=True,
+                 remove_columns=dataset.column_names,
+                 num_proc=2,  # Reduced from 4 for stability
+                 desc=f"tokenise-{dataset_name}",
+                 load_from_cache_file=False,
+                 batch_size=50,  # Reduced from 100 for stability
+                 writer_batch_size=50  # Add writer batch size limit
+             )
+
+             # Save checkpoint after successful processing
+             save_checkpoint(features.to_dict(), checkpoint_path)
+             return features
+
+         except Exception as e:
+             print(f"❌ Preprocessing failed on attempt {attempt + 1}: {e}")
+             if attempt < max_retries - 1:
+                 print("⏳ Waiting 10 seconds before retry...")
+                 time.sleep(10)
+                 gc.collect()  # Clean up memory
+             else:
+                 print("💥 All preprocessing attempts failed!")
+                 raise e
+
+ def main():
+     global val_raw, val_feats, tok
+
+     set_seed(SEED)
+
+     # Create checkpoint directory
+     os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+
+     # Model name to store on Hub
+     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
+
+     if (tokn := os.getenv("roberta_token")):
+         try:
+             login(tokn)
+             print("🔑 HuggingFace Hub login OK")
+         except Exception as e:
+             print(f"⚠️ Hub login failed: {e}")
+             print("📝 Training will continue but won't push to Hub")
+             tokn = None  # Disable pushing
+
+     print("📚 Loading CUAD…")
+     dataset_checkpoint = f"{CHECKPOINT_DIR}/cuad_dataset.pt"
+
+     # Try to load dataset from checkpoint
+     dataset_data = load_checkpoint(dataset_checkpoint)
+     if dataset_data is not None:
+         cuad = Dataset.from_dict(dataset_data)
+         print(f"✅ Loaded dataset from checkpoint: {len(cuad)} examples")
+     else:
+         # Load and process dataset
+         try:
+             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
+             print(f"✅ Loaded {len(cuad)} examples")
+         except Exception as e:
+             print(f"❌ Dataset loading failed: {e}")
+             print("🔄 Retrying with cache disabled...")
+             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")
+
+         cuad = cuad.shuffle(seed=SEED)
+         cuad = balance_has_answer(cuad, ratio=2.0)  # ≈18 k rows
+         print(f"📊 Balanced dataset: {len(cuad)} examples")
+
+         # Save dataset checkpoint
+         save_checkpoint(cuad.to_dict(), dataset_checkpoint)
+
+     # train / val 90-10
+     ds = cuad.train_test_split(test_size=0.1, seed=SEED)
+     train_raw, val_raw = ds["train"], ds["test"]
+
+     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
+     base_ckpt = "deepset/roberta-base-squad2"
+     tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
+     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
+
+     # LoRA
+     lora = LoraConfig(
+         task_type=TaskType.QUESTION_ANS,
+         r=16, lora_alpha=32, lora_dropout=0.05,
+         target_modules=["query", "value"],
     )
+     model = get_peft_model(model, lora)
+     model.print_trainable_parameters()
+
+     # ── preprocess with retry logic ─────────────────────────────────────
+     train_feats = preprocess_with_retry(train_raw, "train")
     # Remove offset_mapping from training data (not needed during training)
-     train_feats = train_feats.remove_columns(["offset_mapping"])
+     if "offset_mapping" in train_feats.column_names:
+         train_feats = train_feats.remove_columns(["offset_mapping"])

-     print("🔄 Preprocessing validation data...")
-     val_feats = val_raw.map(
-         preprocess,
-         batched=True,
-         remove_columns=val_raw.column_names,
-         num_proc=4,  # Use multiple processes for speed
-         desc="tokenise-val",
-         load_from_cache_file=False,
-         batch_size=100  # Process in smaller batches
-     )
+     val_feats = preprocess_with_retry(val_raw, "val")
     # Keep offset_mapping for validation (needed for postprocessing)

     # ── training args ──────────────────────────────────────────────────
@@ -250,9 +305,13 @@ def main():
         greater_is_better=True,
         logging_steps=50,
         report_to="none",
+         # Add resume from checkpoint capability
+         resume_from_checkpoint=True,
+         # Add dataloader settings for stability
+         dataloader_num_workers=0,  # Disable multiprocessing for data loading
+         dataloader_pin_memory=False,  # Reduce memory pressure
     )

-     # FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
     trainer = Trainer(
         model=model,
         args=args,
@@ -264,7 +323,19 @@ def main():
     )

     print("🚀 Training…")
-     trainer.train()
+     try:
+         trainer.train()
+         print("✅ Training completed successfully!")
+     except Exception as e:
+         print(f"❌ Training failed: {e}")
+         print("💾 Attempting to save current state...")
+         try:
+             trainer.save_model("./cuad_lora_out_partial")
+             tok.save_pretrained("./cuad_lora_out_partial")
+             print("💾 Partial model saved to ./cuad_lora_out_partial")
+         except:
+             print("❌ Could not save partial model")
+         raise e

     print("✅ Done. Best F1:", trainer.state.best_metric)
     trainer.save_model("./cuad_lora_out")
@@ -272,16 +343,31 @@ def main():

     # optional: push (with retry logic)
     if tokn:
-         try:
-             print("⬆️ Pushing to Hub...")
-             trainer.push_to_hub(model_repo, private=False)
-             tok.push_to_hub(model_repo, private=False)
-             print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
-         except Exception as e:
-             print(f"⚠️ Hub push failed: {e}")
-             print("💾 Model saved locally in ./cuad_lora_out")
+         max_push_retries = 3
+         for push_attempt in range(max_push_retries):
+             try:
+                 print(f"⬆️ Pushing to Hub (attempt {push_attempt + 1}/{max_push_retries})...")
+                 trainer.push_to_hub(model_repo, private=False)
+                 tok.push_to_hub(model_repo, private=False)
+                 print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
+                 break
+             except Exception as e:
+                 print(f"⚠️ Hub push failed on attempt {push_attempt + 1}: {e}")
+                 if push_attempt < max_push_retries - 1:
+                     print("⏳ Waiting 30 seconds before retry...")
+                     time.sleep(30)
+                 else:
+                     print("💾 Model saved locally in ./cuad_lora_out (push failed)")
     else:
         print("💾 Model saved locally in ./cuad_lora_out (no HF token for push)")

+     # Clean up checkpoints after successful completion
+     try:
+         import shutil
+         shutil.rmtree(CHECKPOINT_DIR)
+         print("🧹 Cleaned up temporary checkpoints")
+     except:
+         print("⚠️ Could not clean up temporary checkpoints")
+
 if __name__ == "__main__":
     main()
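
The diff never shows the body of balance_has_answer(); a hypothetical implementation consistent with its docstring and the ≈18 k-row result quoted in main() (everything beyond the signature is illustrative, not the committed code):

    import random
    from datasets import Dataset

    def balance_has_answer(dataset, ratio=2.0):
        """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
        has, no = [], []
        for ex in dataset:
            # CUAD marks unanswerable questions with an empty answers["text"] list.
            (has if ex["answers"]["text"] else no).append(ex)
        random.shuffle(no)
        rows = has + no[: int(len(has) * ratio)]
        random.shuffle(rows)
        return Dataset.from_list(rows)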
 
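The new save_checkpoint/load_checkpoint helpers persist datasets as plain column dicts through torch.save/torch.load. A minimal sketch of the round-trip they rely on (file name illustrative):

    import torch
    from datasets import Dataset

    ds = Dataset.from_dict({"id": ["a", "b"], "label": [0, 1]})

    # Serialise the columns as a plain dict, as save_checkpoint() does ...
    torch.save(ds.to_dict(), "demo_features.pt")

    # ... and rebuild the Dataset on a warm start, as the loaders do.
    restored = Dataset.from_dict(torch.load("demo_features.pt"))
    assert restored.column_names == ds.column_names

Note that to_dict() materialises every column in memory, so the scheme trades RAM for restartability on the ≈18 k balanced rows.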
 
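target_modules=["query", "value"] depends on PEFT matching those names as suffixes of the RoBERTa attention projections. A quick sanity check, assuming the standard roberta-base layout (12 layers, hence 24 expected matches):

    from transformers import AutoModelForQuestionAnswering

    model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

    # Self-attention exposes linears named "...attention.self.query" / ".value";
    # LoRA is injected wherever a module name ends with a target_modules entry.
    hits = [name for name, _ in model.named_modules()
            if name.endswith(("query", "value"))]
    print(len(hits))  # expect 24 for 12-layer roberta-base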
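
One caveat on the new TrainingArguments entries: transformers documents resume_from_checkpoint as a path meant to be read by training scripts rather than consumed by Trainer itself, so resume_from_checkpoint=True probably does not make the bare trainer.train() call resume. A hedged sketch of the usual explicit pattern, reusing the script's args and trainer:

    from transformers.trainer_utils import get_last_checkpoint

    # Returns the newest "checkpoint-*" directory under output_dir, or None.
    last_ckpt = get_last_checkpoint(args.output_dir)

    # Passing the path (or True) here is what actually triggers a resume.
    trainer.train(resume_from_checkpoint=last_ckpt)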
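
Because the pushed model is a PEFT wrapper, the upload should contain the LoRA adapter rather than merged weights, so inference reloads it on top of the base checkpoint. A sketch using the standard peft API, assuming the adapter lands under the script's default repo id:

    from peft import PeftModel
    from transformers import AutoModelForQuestionAnswering, AutoTokenizer

    base = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
    model = PeftModel.from_pretrained(base, "AvocadoMuffin/roberta-cuad-qa-v2")
    tok = AutoTokenizer.from_pretrained("AvocadoMuffin/roberta-cuad-qa-v2")

    # Optionally fold the LoRA deltas into the base weights for plain inference.
    model = model.merge_and_unload()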