AvocadoMuffin committed
Commit 4337e1a · verified · 1 Parent(s): c48cc67

Update train.py

Files changed (1)
  1. train.py +141 -248
train.py CHANGED
@@ -1,8 +1,8 @@
  #!/usr/bin/env python
- # train_cuad_lora_improved.py
+ # train_cuad_lora_efficient.py
  """
- CUAD fine-tune with LoRA on L40S GPU in HuggingFace Spaces.
- Improved version with better error handling and chunked processing.
+ CUAD fine-tune with LoRA - Efficient batch processing version.
+ Fixes bottlenecks and uses proper batching instead of chunking.
  """

  import os, json, random, gc, time
@@ -21,72 +21,42 @@ from huggingface_hub import login

  disable_caching()

- # ─────────────────────────────────────────────────────────────── helpers ──
+ # ─────────────────────────────────────────────────────────────── config ──

- MAX_LEN = 384
- DOC_STRIDE = 128
- SEED = 42
- CHECKPOINT_DIR = "./cuad_lora_checkpoints"
- CHUNK_SIZE = 100 # Process in smaller chunks to avoid timeouts
+ MAX_LEN = 384
+ DOC_STRIDE = 128
+ SEED = 42
+ BATCH_SIZE = 1000 # Process in larger, more efficient batches
+
+ # Reduced dataset size option
+ USE_SUBSET = True # Set to True to use only 10k examples
+ SUBSET_SIZE = 10000

  def set_seed(seed):
      random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)

- def save_checkpoint(data, checkpoint_path):
-     """Save preprocessing checkpoint to disk"""
-     os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
-     torch.save(data, checkpoint_path)
-     print(f"💾 Checkpoint saved: {checkpoint_path}")
-
- def load_checkpoint(checkpoint_path):
-     """Load preprocessing checkpoint from disk"""
-     if os.path.exists(checkpoint_path):
-         print(f"📂 Loading checkpoint: {checkpoint_path}")
-         return torch.load(checkpoint_path, map_location='cpu')
-     return None
-
- def save_partial_features(features_dict, chunk_idx, dataset_name):
-     """Save partial features for a chunk"""
-     partial_path = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
-     save_checkpoint(features_dict, partial_path)
-     return partial_path
-
- def load_and_combine_chunks(dataset_name):
-     """Load all chunk files and combine them"""
-     chunk_files = []
-     if os.path.exists(CHECKPOINT_DIR):
-         for f in os.listdir(CHECKPOINT_DIR):
-             if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
-                 chunk_files.append(os.path.join(CHECKPOINT_DIR, f))
-
-     if not chunk_files:
-         return None
-
-     chunk_files.sort()
-     print(f"📂 Found {len(chunk_files)} chunks for {dataset_name}")
-
-     # Combine all chunks
-     combined = None
-     for chunk_file in chunk_files:
-         chunk_data = torch.load(chunk_file, map_location='cpu')
-         if combined is None:
-             combined = chunk_data
-         else:
-             for key in chunk_data:
-                 combined[key].extend(chunk_data[key])
-
-     print(f"✅ Combined {len(combined['input_ids'])} features from chunks")
-     return combined
-
- def balance_has_answer(dataset, ratio=2.0):
+ def balance_has_answer(dataset, ratio=2.0, max_samples=None):
      """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
      has, no = [], []
      for ex in dataset:
          (has if ex["answers"]["text"] else no).append(ex)
+
+     print(f"📊 Original: {len(has)} has-answer, {len(no)} no-answer")
+
      k = int(len(has) * ratio)
      no = random.sample(no, min(k, len(no)))
-     return Dataset.from_list(has + no)
+
+     balanced = has + no
+
+     # Apply subset limit if specified
+     if max_samples and len(balanced) > max_samples:
+         balanced = random.sample(balanced, max_samples)
+         print(f"📉 Reduced to {max_samples} samples for faster training")
+
+     print(f"📊 Balanced: {len([x for x in balanced if x['answers']['text']])} has-answer, {len([x for x in balanced if not x['answers']['text']])} no-answer")
+
+     return Dataset.from_list(balanced)

  # ────────────────────────────────────────────────────────────── postproc ──

@@ -104,20 +74,20 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):

      for example_idx, example in enumerate(examples):
          best_score = -1e9
-         best_span = ""
-         context = example["context"]
+         best_span = ""
+         context = example["context"]

          for feat_idx in features_per_example[example_idx]:
              start_logit = all_start[feat_idx]
-             end_logit = all_end[feat_idx]
-             offset = features["offset_mapping"][feat_idx]
+             end_logit = all_end[feat_idx]
+             offset = features["offset_mapping"][feat_idx]

              start_idx = int(np.argmax(start_logit))
-             end_idx = int(np.argmax(end_logit))
+             end_idx = int(np.argmax(end_logit))

              if start_idx <= end_idx < len(offset):
                  start_char, _ = offset[start_idx]
-                 _, end_char = offset[end_idx]
+                 _, end_char = offset[end_idx]
                  span = context[start_char:end_char].strip()
                  score = start_logit[start_idx] + end_logit[end_idx]
                  if score > best_score and span:
@@ -131,19 +101,25 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):
  def compute_metrics(eval_pred):
      """Use regular eval_pred structure and correct variable names"""
      predictions = postprocess_qa(val_raw, val_feats, eval_pred.predictions, tok)
-     references = [
+     references = [
          {"id": ex["id"], "answers": ex["answers"]} for ex in val_raw
      ]
      return metric.compute(predictions=predictions, references=references)

- # ───────────────────────────────────────────────────────────────── main ──
+ # ───────────────────────────────────────────────────────────── preprocessing ──

- def preprocess_single_example(example, tokenizer):
-     """Process a single example to avoid batch processing issues"""
-     # Tokenize
-     tokenized = tokenizer(
-         example["question"],
-         example["context"],
+ def preprocess_batch_efficient(examples, tokenizer):
+     """
+     Efficient batch preprocessing using HuggingFace's built-in batch processing.
+     This is much faster than processing examples individually.
+     """
+     questions = examples["question"]
+     contexts = examples["context"]
+
+     # Batch tokenization - this is the key efficiency gain
+     tokenized_examples = tokenizer(
+         questions,
+         contexts,
          truncation="only_second",
          max_length=MAX_LEN,
          stride=DOC_STRIDE,
@@ -152,160 +128,89 @@ def preprocess_single_example(example, tokenizer):
          padding="max_length",
      )

-     results = {
-         "input_ids": [],
-         "attention_mask": [],
-         "start_positions": [],
-         "end_positions": [],
-         "example_id": [],
-         "offset_mapping": []
-     }
+     # Map back to original examples
+     sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+     # Initialize output
+     start_positions = []
+     end_positions = []

-     for i in range(len(tokenized["input_ids"])):
-         results["input_ids"].append(tokenized["input_ids"][i])
-         results["attention_mask"].append(tokenized["attention_mask"][i])
-         results["offset_mapping"].append(tokenized["offset_mapping"][i])
-         results["example_id"].append(example["id"])
+     for i, offsets in enumerate(tokenized_examples["offset_mapping"]):
+         input_ids = tokenized_examples["input_ids"][i]
+         cls_index = 0 # CLS token position

-         # Handle answer positions
-         answers = example["answers"]
-         offsets = tokenized["offset_mapping"][i]
-         cls_index = 0
+         # Get the original example for this tokenized chunk
+         sample_index = sample_mapping[i]
+         answers = examples["answers"][sample_index]

+         # Handle cases with no answer
          if not answers["text"] or not answers["text"][0]:
-             results["start_positions"].append(cls_index)
-             results["end_positions"].append(cls_index)
+             start_positions.append(cls_index)
+             end_positions.append(cls_index)
              continue
-
-         answer_start = answers["answer_start"][0]
+
+         # Find answer span in tokens
+         answer_start_char = answers["answer_start"][0]
          answer_text = answers["text"][0]
-         answer_end = answer_start + len(answer_text)
+         answer_end_char = answer_start_char + len(answer_text)

-         start_token = end_token = cls_index
+         # Find token positions
+         token_start_index = cls_index
+         token_end_index = cls_index

-         for tok_idx, (start_char, end_char) in enumerate(offsets):
-             if start_char <= answer_start < end_char:
-                 start_token = tok_idx
-             if start_char < answer_end <= end_char:
-                 end_token = tok_idx
+         for token_index, (start_char, end_char) in enumerate(offsets):
+             if start_char <= answer_start_char < end_char:
+                 token_start_index = token_index
+             if start_char < answer_end_char <= end_char:
+                 token_end_index = token_index
                  break

-         if start_token <= end_token and start_token > 0:
-             results["start_positions"].append(start_token)
-             results["end_positions"].append(end_token)
+         # Validate positions
+         if token_start_index <= token_end_index and token_start_index > 0:
+             start_positions.append(token_start_index)
+             end_positions.append(token_end_index)
          else:
-             results["start_positions"].append(cls_index)
-             results["end_positions"].append(cls_index)
-
-     return results
-
- def preprocess_with_chunking(dataset, dataset_name, tokenizer):
-     """Process dataset in chunks to avoid timeouts"""
+             start_positions.append(cls_index)
+             end_positions.append(cls_index)

-     # Check if final result already exists
-     final_checkpoint = f"{CHECKPOINT_DIR}/{dataset_name}_features.pt"
-     final_features = load_checkpoint(final_checkpoint)
-     if final_features is not None:
-         print(f"✅ Loaded {dataset_name} features from final checkpoint")
-         return Dataset.from_dict(final_features)
+     tokenized_examples["start_positions"] = start_positions
+     tokenized_examples["end_positions"] = end_positions

-     # Check if we can resume from chunks
-     combined_features = load_and_combine_chunks(dataset_name)
-     if combined_features is not None:
-         # Save as final checkpoint
-         save_checkpoint(combined_features, final_checkpoint)
-         return Dataset.from_dict(combined_features)
-
-     # Process in chunks
-     print(f"🔄 Processing {dataset_name} dataset in chunks of {CHUNK_SIZE}...")
-
-     total_samples = len(dataset)
-     num_chunks = (total_samples + CHUNK_SIZE - 1) // CHUNK_SIZE
-
-     for chunk_idx in range(num_chunks):
-         chunk_file = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
-
-         # Skip if chunk already processed
-         if os.path.exists(chunk_file):
-             print(f"⏭️ Chunk {chunk_idx + 1}/{num_chunks} already exists, skipping...")
-             continue
-
-         start_idx = chunk_idx * CHUNK_SIZE
-         end_idx = min(start_idx + CHUNK_SIZE, total_samples)
-
-         print(f"🔄 Processing chunk {chunk_idx + 1}/{num_chunks} (samples {start_idx}-{end_idx-1})...")
-
-         chunk_results = {
-             "input_ids": [],
-             "attention_mask": [],
-             "start_positions": [],
-             "end_positions": [],
-             "example_id": [],
-             "offset_mapping": []
-         }
-
-         # Process each example in the chunk individually
-         for i in range(start_idx, end_idx):
-             if i % 10 == 0: # Progress indicator
-                 print(f" Processing sample {i}/{total_samples}")
-
-             try:
-                 example = dataset[i]
-                 result = preprocess_single_example(example, tokenizer)
-
-                 # Add to chunk results
-                 for key in chunk_results:
-                     chunk_results[key].extend(result[key])
-
-             except Exception as e:
-                 print(f"⚠️ Error processing sample {i}: {e}")
-                 continue
-
-         # Save chunk
-         save_partial_features(chunk_results, chunk_idx, dataset_name)
-
-         # Clean up memory
-         del chunk_results
-         gc.collect()
-
-         print(f"✅ Chunk {chunk_idx + 1}/{num_chunks} completed and saved")
-
-     # Combine all chunks
-     print("🔄 Combining all chunks...")
-     combined_features = load_and_combine_chunks(dataset_name)
-
-     if combined_features is None:
-         raise RuntimeError("Failed to load and combine chunks!")
+     # Add example IDs for evaluation
+     tokenized_examples["example_id"] = [
+         examples["id"][sample_mapping[i]] for i in range(len(tokenized_examples["input_ids"]))
+     ]

-     # Save final result
-     save_checkpoint(combined_features, final_checkpoint)
+     return tokenized_examples
+
+ def preprocess_dataset_streaming(dataset, tokenizer, desc="Processing"):
+     """
+     Process dataset in batches using HuggingFace's map function with batching.
+     This is much more memory efficient and faster than manual chunking.
+     """
+     print(f"🔄 {desc} dataset with batch processing...")

-     # Clean up chunk files
-     cleanup_chunk_files(dataset_name)
+     processed = dataset.map(
+         lambda examples: preprocess_batch_efficient(examples, tokenizer),
+         batched=True,
+         batch_size=BATCH_SIZE,
+         remove_columns=dataset.column_names,
+         desc=desc,
+         num_proc=1, # Use 1 process to avoid memory issues in Spaces
+     )

-     return Dataset.from_dict(combined_features)
-
- def cleanup_chunk_files(dataset_name):
-     """Remove chunk files after successful combination"""
-     if os.path.exists(CHECKPOINT_DIR):
-         for f in os.listdir(CHECKPOINT_DIR):
-             if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
-                 try:
-                     os.remove(os.path.join(CHECKPOINT_DIR, f))
-                 except:
-                     pass
-     print(f"🧹 Cleaned up chunk files for {dataset_name}")
+     print(f"✅ {desc} completed: {len(processed)} features")
+     return processed
+
+ # ────────────────────────────────────────────────────────────────── main ──

  def main():
      global val_raw, val_feats, tok

      set_seed(SEED)
-
-     # Create checkpoint directory
-     os.makedirs(CHECKPOINT_DIR, exist_ok=True)

      # Model name to store on Hub
-     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
+     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v3")

      if (tokn := os.getenv("roberta_token")):
          try:
@@ -316,27 +221,19 @@ def main():
              tokn = None

      print("📚 Loading CUAD…")
-     dataset_checkpoint = f"{CHECKPOINT_DIR}/cuad_dataset.pt"
+     try:
+         cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
+         print(f"✅ Loaded {len(cuad)} examples")
+     except Exception as e:
+         print(f"❌ Dataset loading failed: {e}")
+         cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")

-     # Try to load dataset from checkpoint
-     dataset_data = load_checkpoint(dataset_checkpoint)
-     if dataset_data is not None:
-         cuad = Dataset.from_dict(dataset_data)
-         print(f"✅ Loaded dataset from checkpoint: {len(cuad)} examples")
-     else:
-         try:
-             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
-             print(f"✅ Loaded {len(cuad)} examples")
-         except Exception as e:
-             print(f"❌ Dataset loading failed: {e}")
-             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")
-
-     cuad = cuad.shuffle(seed=SEED)
-     cuad = balance_has_answer(cuad, ratio=2.0)
-     print(f"📊 Balanced dataset: {len(cuad)} examples")
-
-     # Save dataset checkpoint
-     save_checkpoint(cuad.to_dict(), dataset_checkpoint)
+     cuad = cuad.shuffle(seed=SEED)
+
+     # Apply subset reduction if enabled
+     subset_size = SUBSET_SIZE if USE_SUBSET else None
+     cuad = balance_has_answer(cuad, ratio=2.0, max_samples=subset_size)
+     print(f"📊 Final dataset size: {len(cuad)} examples")

      # train / val 90-10
      ds = cuad.train_test_split(test_size=0.1, seed=SEED)
@@ -347,42 +244,46 @@ def main():
      tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
      model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

-     # LoRA
+     # LoRA with slightly more aggressive settings for smaller dataset
      lora = LoraConfig(
          task_type=TaskType.QUESTION_ANS,
-         r=16, lora_alpha=32, lora_dropout=0.05,
-         target_modules=["query", "value"],
+         r=32, lora_alpha=64, lora_dropout=0.1, # Increased for better learning with less data
+         target_modules=["query", "value", "key", "dense"], # More modules for better coverage
      )
      model = get_peft_model(model, lora)
      model.print_trainable_parameters()

-     # ── preprocess with chunking ─────────────────────────────────────────
-     print("🔄 Starting preprocessing...")
+     # ── efficient preprocessing ─────────────────────────────────────────
+     print("🔄 Starting efficient preprocessing...")

-     train_feats = preprocess_with_chunking(train_raw, "train", tok)
-     # Remove offset_mapping from training data
+     # Process training data
+     train_feats = preprocess_dataset_streaming(train_raw, tok, "Training")
+     # Remove offset_mapping for training
      if "offset_mapping" in train_feats.column_names:
          train_feats = train_feats.remove_columns(["offset_mapping"])

-     val_feats = preprocess_with_chunking(val_raw, "val", tok)
-     # Keep offset_mapping for validation
+     # Process validation data (keep offset_mapping for evaluation)
+     val_feats = preprocess_dataset_streaming(val_raw, tok, "Validation")

      print(f"✅ Preprocessing completed!")
      print(f" Training features: {len(train_feats)}")
      print(f" Validation features: {len(val_feats)}")

      # ── training args ──────────────────────────────────────────────────
+     # Adjusted for smaller dataset
+     total_steps = (len(train_feats) // 16 // 2) * 6 # Rough estimate
+
      args = TrainingArguments(
          output_dir="./cuad_lora_out",
-         learning_rate=3e-5,
-         num_train_epochs=4,
-         per_device_train_batch_size=16, # Increased for L40S
+         learning_rate=5e-5, # Slightly higher for smaller dataset
+         num_train_epochs=6 if USE_SUBSET else 4, # More epochs for smaller dataset
+         per_device_train_batch_size=16,
          per_device_eval_batch_size=16,
-         gradient_accumulation_steps=2, # Reduced since batch size increased
+         gradient_accumulation_steps=2,
          fp16=False, bf16=True,
          eval_strategy="steps",
-         eval_steps=250,
-         save_steps=500,
+         eval_steps=max(100, total_steps // 20), # Adaptive eval steps
+         save_steps=max(200, total_steps // 10), # Adaptive save steps
          save_total_limit=2,
          weight_decay=0.01,
          lr_scheduler_type="cosine",
@@ -390,11 +291,11 @@ def main():
          load_best_model_at_end=True,
          metric_for_best_model="f1",
          greater_is_better=True,
-         logging_steps=50,
+         logging_steps=25,
          report_to="none",
-         resume_from_checkpoint=True,
-         dataloader_num_workers=2, # L40S can handle more workers
+         dataloader_num_workers=2,
          dataloader_pin_memory=True,
+         remove_unused_columns=False, # Keep example_id for evaluation
      )

      trainer = Trainer(
@@ -441,13 +342,5 @@ def main():
      else:
          print("💾 Model saved locally (push failed)")

-     # Clean up checkpoints
-     try:
-         import shutil
-         shutil.rmtree(CHECKPOINT_DIR)
-         print("🧹 Cleaned up temporary checkpoints")
-     except:
-         print("⚠️ Could not clean up checkpoints")
-
  if __name__ == "__main__":
      main()
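
For reference, the core pattern this commit switches to (fast-tokenizer overflow handling combined with Dataset.map using batched=True) can be exercised on its own. Below is a minimal sketch, not code from the commit: the "roberta-base" checkpoint and the toy single-example dataset are illustrative placeholders.

# Minimal sketch of batched QA preprocessing (placeholders: "roberta-base", toy data).
from datasets import Dataset
from transformers import AutoTokenizer

MAX_LEN, DOC_STRIDE, BATCH_SIZE = 384, 128, 1000
tok = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)

toy = Dataset.from_dict({
    "id": ["ex0"],
    "question": ["Who are the parties to the agreement?"],
    "context": ["This Agreement is made between Acme Corp and Beta LLC. " * 60],
    "answers": [{"text": ["Acme Corp"], "answer_start": [31]}],
})

def tokenize_batch(examples):
    # One long contract can overflow into several fixed-length features.
    enc = tok(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Map each overflow feature back to the example it came from.
    sample_map = enc.pop("overflow_to_sample_mapping")
    enc["example_id"] = [examples["id"][i] for i in sample_map]
    return enc

feats = toy.map(tokenize_batch, batched=True, batch_size=BATCH_SIZE,
                remove_columns=toy.column_names)
print(f"{len(feats)} features from {len(toy)} example(s)")

Because the whole batch is tokenized in one call and the overflow-to-sample mapping is handled by the tokenizer, this avoids the per-example loop and manual chunk checkpoints that the removed preprocess_with_chunking path needed.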