AvocadoMuffin committed on · verified
Commit d8f4b0f · 1 Parent(s): 06dc84e

Update train.py

Files changed (1):
  1. train.py +163 -375

train.py CHANGED
@@ -1,409 +1,197 @@
- import torch, gc, os, numpy as np, evaluate, json
- from datasets import load_dataset
  from transformers import (
      AutoTokenizer, AutoModelForQuestionAnswering,
-     TrainingArguments, Trainer, default_data_collator
  )
  from peft import LoraConfig, get_peft_model, TaskType
  from huggingface_hub import login
- import sys

  def main():
-     # Get model name from environment
-     model_name = os.environ.get('MODEL_NAME', 'AvocadoMuffin/roberta-cuad-qa')
-
-     # Login to HF Hub
-     hf_token = os.environ.get('roberta_token')
-     if hf_token:
-         try:
-             login(token=hf_token)
-             print("✅ Logged into Hugging Face Hub")
-         except Exception as e:
-             print(f"⚠️ HF Hub login failed: {e}")
-             print("⚠️ Model won't be pushed to Hub")
-             hf_token = None
-     else:
-         print("⚠️ No roberta_token found - model won't be pushed to Hub")
-
-     # Setup
-     torch.cuda.empty_cache()
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print(f"🔧 Using device: {device}")
-
-     if torch.cuda.is_available():
-         print(f"🎯 GPU: {torch.cuda.get_device_name()}")
-         print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
-
-     # Load and prepare data - OPTIMIZED SIZE FOR FASTER TRAINING
-     print("📚 Loading CUAD dataset...")
-     try:
-         raw = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
-     except Exception as e:
-         print(f"❌ Failed to load dataset: {e}")
-         return
-
-     # Use 4000 samples for good model quality - expect ~1 hour training
-     N = 5000 # Good balance of quality and reasonable training time
-     raw = raw.shuffle(seed=42).select(range(min(N, len(raw))))
-     ds = raw.train_test_split(test_size=0.1, seed=42)
-     train_ds, val_ds = ds["train"], ds["test"]
-
-     print(f"✅ Data loaded - Train: {len(train_ds)}, Val: {len(val_ds)}")
-
-     # Store original validation data for metrics - CRITICAL FOR CORRECT EVALUATION
-     print("📊 Preparing metrics data...")
-     original_val_data = []
-
-     # Store validation answers before tokenization
-     for i, ex in enumerate(val_ds):
-         original_val_data.append(ex["answers"])
-
-     # Load model and tokenizer
-     print("🤖 Loading RoBERTa model...")
-     base_model = "roberta-base"
-     try:
-         tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
-         model = AutoModelForQuestionAnswering.from_pretrained(base_model)
-     except Exception as e:
-         print(f"❌ Failed to load model/tokenizer: {e}")
-         return
-
-     # Add LoRA
-     print("🔧 Adding LoRA adapters...")
-     lora_cfg = LoraConfig(
          task_type=TaskType.QUESTION_ANS,
          target_modules=["query", "value"],
-         r=16,
-         lora_alpha=32,
-         lora_dropout=0.05,
      )
-     model = get_peft_model(model, lora_cfg)
      model.print_trainable_parameters()
-     model.to(device)
-
-     # Tokenization function - OPTIMIZED TO PREVENT EXCESSIVE EXPANSION
-     max_len, doc_stride = 512, 400 # Large stride to minimize chunks per document
-
      def preprocess(examples):
-         tok_batch = tok(
              examples["question"],
              examples["context"],
              truncation="only_second",
-             max_length=max_len,
-             stride=doc_stride,
              return_overflowing_tokens=True,
              return_offsets_mapping=True,
              padding="max_length",
-         )

-         sample_map = tok_batch.pop("overflow_to_sample_mapping")
-         offset_map = tok_batch.pop("offset_mapping")
-         start_pos, end_pos = [], []
-
-         for i, offsets in enumerate(offset_map):
-             cls_idx = tok_batch["input_ids"][i].index(tok.cls_token_id)
-             sample_idx = sample_map[i]
-             answer = examples["answers"][sample_idx]
-
-             if len(answer["answer_start"]) == 0:
-                 start_pos.append(cls_idx)
-                 end_pos.append(cls_idx)
-                 continue
-
-             s_char = answer["answer_start"][0]
-             e_char = s_char + len(answer["text"][0])
-             seq_ids = tok_batch.sequence_ids(i)
-             c0, c1 = seq_ids.index(1), len(seq_ids) - 1 - seq_ids[::-1].index(1)
-
-             if not (offsets[c0][0] <= s_char <= offsets[c1][1]):
-                 start_pos.append(cls_idx)
-                 end_pos.append(cls_idx)
-                 continue
-
-             st = c0
-             while st <= c1 and offsets[st][0] <= s_char:
-                 st += 1
-             en = c1
-             while en >= c0 and offsets[en][1] >= e_char:
-                 en -= 1
-
-             # Fixed position calculation with bounds checking
-             start_pos.append(max(c0, min(st - 1, c1)))
-             end_pos.append(max(c0, min(en + 1, c1)))
-
-         tok_batch["start_positions"] = start_pos
-         tok_batch["end_positions"] = end_pos
-         # Store sample mapping for metrics calculation
-         tok_batch["sample_mapping"] = sample_map
-         return tok_batch
-
-     # Tokenize datasets
-     print("🔄 Tokenizing datasets...")
-     try:
-         train_tok = train_ds.map(
-             preprocess, batched=True, batch_size=50,
-             remove_columns=train_ds.column_names,
-             desc="Tokenizing train"
-         )
-         val_tok = val_ds.map(
-             preprocess, batched=True, batch_size=50,
-             remove_columns=val_ds.column_names,
-             desc="Tokenizing validation"
-         )
-     except Exception as e:
-         print(f"❌ Tokenization failed: {e}")
-         return
-
-     # DEBUG: Print actual dataset sizes after tokenization
-     print(f"🔍 DEBUG INFO:")
-     print(f" Original samples: {N}")
-     print(f" After tokenization - Train: {len(train_tok)}, Val: {len(val_tok)}")
-     print(f" Expansion factor: {len(train_tok)/len(train_ds):.1f}x")
-
-     # SAFETY CHECK: If expansion is too high, reduce data size automatically
-     expansion_factor = len(train_tok) / len(train_ds)
-     if expansion_factor > 12: # Slightly more permissive for 4K samples
-         print(f"⚠️ HIGH EXPANSION DETECTED ({expansion_factor:.1f}x)!")
-         print("🔧 Auto-reducing dataset size to prevent excessively slow training...")
-
-         # Allow up to 20k samples for 1 hour training
-         target_size = min(20000, len(train_tok)) # Max 20k samples
-         train_indices = list(range(0, len(train_tok), max(1, len(train_tok) // target_size)))[:target_size]
-         val_indices = list(range(0, len(val_tok), max(1, len(val_tok) // (target_size // 10))))[:target_size // 10]
-
-         train_tok = train_tok.select(train_indices)
-         val_tok = val_tok.select(val_indices)
-
-         print(f"✅ Reduced to - Train: {len(train_tok)}, Val: {len(val_tok)}")
-         print(f"📈 This should complete in ~45-75 minutes")
-
-     # Clean up memory
-     del raw, ds, train_ds, val_ds
-     gc.collect()
-     torch.cuda.empty_cache()
-
-     # FIXED: Metrics setup with proper error handling
-     try:
-         metric = evaluate.load("squad")
-     except Exception as e:
-         print(f"⚠️ Failed to load SQuAD metric: {e}")
-         metric = None
-
-     def compute_metrics(eval_pred):
-         if metric is None:
-             print("⚠️ No metric available, returning dummy scores")
-             return {"exact_match": 0.0, "f1": 0.0}
-
-         try:
-             preds, _ = eval_pred
-             starts, ends = preds
-
-             # Group predictions by original sample (handle multiple chunks per sample)
-             sample_predictions = {}
-
-             for i in range(len(starts)):
-                 # FIXED: Proper dictionary access without hasattr
-                 if 'sample_mapping' in val_tok[i]:
-                     orig_idx = val_tok[i]['sample_mapping']
-                 else:
-                     # Fallback: assume 1:1 mapping (may be inaccurate with chunking)
-                     orig_idx = min(i, len(original_val_data) - 1)
-
-                 # Get best answer span for this chunk
-                 start_idx = int(np.argmax(starts[i]))
-                 end_idx = int(np.argmax(ends[i]))
-                 if start_idx > end_idx:
-                     start_idx, end_idx = end_idx, start_idx
-
-                 # Extract answer text
-                 try:
-                     answer_text = tok.decode(
-                         val_tok[i]["input_ids"][start_idx:end_idx+1],
-                         skip_special_tokens=True
-                     ).strip()
-                 except Exception:
-                     answer_text = ""
-
-                 # Store best prediction for this original sample
-                 confidence = float(starts[i][start_idx]) + float(ends[i][end_idx])
-                 if orig_idx not in sample_predictions or confidence > sample_predictions[orig_idx][1]:
-                     sample_predictions[orig_idx] = (answer_text, confidence)
-
-             # Format for SQuAD metric
-             predictions = []
-             references = []
-
-             for orig_idx in range(len(original_val_data)):
-                 pred_text = sample_predictions.get(orig_idx, ("", 0))[0]
-                 predictions.append({
-                     "id": str(orig_idx),
-                     "prediction_text": pred_text
-                 })
-                 references.append({
-                     "id": str(orig_idx),
-                     "answers": original_val_data[orig_idx]
-                 })
-
-             result = metric.compute(predictions=predictions, references=references)
-
-             # Add some debugging info
-             print(f"📊 Evaluation: EM={result['exact_match']:.3f}, F1={result['f1']:.3f}")
-             return result
-
-         except Exception as e:
-             print(f"⚠️ Metrics computation failed: {e}")
-             print(f" Pred shape: {np.array(preds).shape if preds else 'None'}")
-             print(f" Val dataset size: {len(val_tok)}")
-             print(f" Original val size: {len(original_val_data)}")
-             return {"exact_match": 0.0, "f1": 0.0}
-
-     # OPTIMIZED Training arguments
-     output_dir = "./model_output"
      args = TrainingArguments(
-         output_dir=output_dir,
-         per_device_train_batch_size=8, # INCREASED from 2
-         per_device_eval_batch_size=8, # INCREASED from 4
-         gradient_accumulation_steps=2, # REDUCED from 8
-         num_train_epochs=3, # 3 epochs for good training
-         learning_rate=5e-4,
-         lr_scheduler_type="cosine",
-         warmup_ratio=0.1,
-         bf16=True, # Better for newer GPUs
-         eval_strategy="steps",
-         eval_steps=100, # More frequent evaluation
-         save_steps=200, # More frequent saving
          save_total_limit=2,
-         logging_steps=25, # More frequent logging
          weight_decay=0.01,
-         remove_unused_columns=True,
-         report_to=None,
-         push_to_hub=False, # We'll push manually
-         dataloader_pin_memory=True, # Faster data loading
-         dataloader_num_workers=4, # Parallel data loading
-         gradient_checkpointing=False, # Trade memory for speed
-         load_best_model_at_end=True, # Load best model
-         metric_for_best_model="f1", # Use F1 score
          greater_is_better=True,
      )

-     # Create trainer
-     trainer = Trainer(
          model=model,
          args=args,
-         train_dataset=train_tok,
-         eval_dataset=val_tok,
          tokenizer=tok,
          data_collator=default_data_collator,
          compute_metrics=compute_metrics,
      )

-     print(f"🚀 Starting training...")
-     print(f"📊 Total training samples: {len(train_tok)}")
-     print(f"📊 Total validation samples: {len(val_tok)}")
-     print(f" Effective batch size: {args.per_device_train_batch_size * args.gradient_accumulation_steps}")
-
-     if torch.cuda.is_available():
-         print(f"💾 GPU memory before training: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
-
-     # Training loop with error handling
-     try:
-         trainer.train()
-         print(" Training completed successfully!")
-
-     except RuntimeError as e:
-         if "CUDA out of memory" in str(e):
-             print("⚠️ GPU OOM - reducing batch size and retrying...")
-             torch.cuda.empty_cache()
-             gc.collect()
-
-             # Reduce batch size
-             args.per_device_train_batch_size = 4
-             args.gradient_accumulation_steps = 4
-
-             trainer = Trainer(
-                 model=model, args=args,
-                 train_dataset=train_tok, eval_dataset=val_tok,
-                 tokenizer=tok, data_collator=default_data_collator,
-                 compute_metrics=compute_metrics,
-             )
-             trainer.train()
-             print("✅ Training completed with reduced batch size!")
-         else:
-             print(f"❌ Training failed: {e}")
-             raise e
-     except Exception as e:
-         print(f"❌ Unexpected training error: {e}")
-         return
-
-     # Save model locally first
-     print("💾 Saving model locally...")
-     try:
-         os.makedirs(output_dir, exist_ok=True)
-         trainer.model.save_pretrained(output_dir)
-         tok.save_pretrained(output_dir)
-         print("✅ Model saved locally")
-     except Exception as e:
-         print(f"❌ Failed to save model locally: {e}")
-         return
-
-     # Save training info
-     training_info = {
-         "model_name": model_name,
-         "base_model": base_model,
-         "dataset": "theatticusproject/cuad-qa",
-         "original_samples": N,
-         "training_samples_after_tokenization": len(train_tok),
-         "validation_samples_after_tokenization": len(val_tok),
-         "lora_config": {
-             "r": lora_cfg.r,
-             "lora_alpha": lora_cfg.lora_alpha,
-             "target_modules": lora_cfg.target_modules,
-             "lora_dropout": lora_cfg.lora_dropout,
-         },
-         "training_args": {
-             "batch_size": args.per_device_train_batch_size,
-             "gradient_accumulation_steps": args.gradient_accumulation_steps,
-             "effective_batch_size": args.per_device_train_batch_size * args.gradient_accumulation_steps,
-             "epochs": args.num_train_epochs,
-             "learning_rate": args.learning_rate,
-         }
-     }
-
-     try:
-         with open(os.path.join(output_dir, "training_info.json"), "w") as f:
-             json.dump(training_info, f, indent=2)
-     except Exception as e:
-         print(f"⚠️ Failed to save training info: {e}")
-
-     # Push to Hub if token available
-     if hf_token:
-         try:
-             print(f"⬆️ Pushing model to Hub: {model_name}")
-             trainer.model.push_to_hub(model_name, private=False)
-             tok.push_to_hub(model_name, private=False)
-
-             # Also push training info
-             try:
-                 from huggingface_hub import upload_file
-                 upload_file(
-                     path_or_fileobj=os.path.join(output_dir, "training_info.json"),
-                     path_in_repo="training_info.json",
-                     repo_id=model_name,
-                     repo_type="model"
-                 )
-                 print("📊 Training info uploaded")
-             except Exception as e:
-                 print(f"⚠️ Training info upload failed: {e}")
-
-             print(f"🎉 Model successfully saved to: https://huggingface.co/{model_name}")
-
-         except Exception as e:
-             print(f"❌ Failed to push to Hub: {e}")
-             print("💾 Model saved locally in ./model_output/")
-     else:
-         print("💾 Model saved locally in ./model_output/ (no HF token for Hub upload)")
-
-     print("🏁 Training pipeline completed!")

  if __name__ == "__main__":
-     main()
 
+ #!/usr/bin/env python
+ # train_cuad_lora.py
+ """
+ CUAD fine-tune with LoRA on an L4 / T4 GPU.
+ Expected wall-clock on Nvidia L4: ~25-30 min.
+ """
+
+ import os, json, random, gc
+ from collections import defaultdict
+
+ import torch, numpy as np
+ from datasets import load_dataset, Dataset, disable_caching
  from transformers import (
      AutoTokenizer, AutoModelForQuestionAnswering,
+     TrainingArguments, default_data_collator
  )
+ from transformers import QuestionAnsweringTrainer, EvalPrediction
  from peft import LoraConfig, get_peft_model, TaskType
+ import evaluate
  from huggingface_hub import login
+
+ disable_caching()  # avoids giant disk cache on Colab
+
+ # ─────────────────────────────────────────────────────────────── helpers ──
+
+ MAX_LEN = 384  # window
+ DOC_STRIDE = 128
+ SEED = 42
+
+ def set_seed(seed):
+     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def balance_has_answer(dataset, ratio=2.0):
+     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
+     has, no = [], []
+     for ex in dataset:
+         (has if ex["answers"]["text"] else no).append(ex)
+     k = int(len(has) * ratio)
+     no = random.sample(no, min(k, len(no)))
+     return Dataset.from_list(has + no)
+
+ # ────────────────────────────────────────────────────────────── postproc ──
+
+ metric = evaluate.load("squad")
+
+ def postprocess_qa(examples, features, raw_predictions, tokenizer):
+     """HF-style span extraction + n-best, returns SQuAD format dict."""
+     all_start, all_end = raw_predictions
+     example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+     features_per_example = defaultdict(list)
+     for i, feat_id in enumerate(features["example_id"]):
+         features_per_example[example_id_to_index[feat_id]].append(i)
+
+     predictions = []
+
+     for example_idx, example in enumerate(examples):
+         best_score = -1e9
+         best_span = ""
+         context = example["context"]
+
+         for feat_idx in features_per_example[example_idx]:
+             start_logit = all_start[feat_idx]
+             end_logit = all_end[feat_idx]
+             offset = features["offset_mapping"][feat_idx]
+
+             start_idx = int(np.argmax(start_logit))
+             end_idx = int(np.argmax(end_logit))
+
+             if start_idx <= end_idx < len(offset):
+                 start_char, _ = offset[start_idx]
+                 _, end_char = offset[end_idx]
+                 span = context[start_char:end_char].strip()
+                 score = start_logit[start_idx] + end_logit[end_idx]
+                 if score > best_score and span:
+                     best_score, best_span = score, span
+
+         predictions.append(
+             {"id": example["id"], "prediction_text": best_span}
+         )
+     return predictions
+
+ def compute_metrics(eval_pred: EvalPrediction):
+     predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions, tok)
+     references = [
+         {"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
+     ]
+     return metric.compute(predictions=predictions, references=references)
+
+ # ───────────────────────────────────────────────────────────────── main ──

  def main():
+     set_seed(SEED)
+
+     # model name to store on Hub
+     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
+
+     if (tokn := os.getenv("roberta_token")):
+         try: login(tokn); print("🔑 HuggingFace Hub login OK")
+         except Exception as e: print("Hub login failed:", e)
+
+     print("📚 Loading CUAD…")
+     cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
+     cuad = cuad.shuffle(seed=SEED)
+     cuad = balance_has_answer(cuad, ratio=2.0)  # ≈18 k rows
+
+     # train / val 90-10
+     ds = cuad.train_test_split(test_size=0.1, seed=SEED)
+     train_raw, val_raw = ds["train"], ds["test"]
+
+     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
+     base_ckpt = "deepset/roberta-base-squad2"
+     tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
+     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
+
+     # LoRA
+     lora = LoraConfig(
          task_type=TaskType.QUESTION_ANS,
+         r=16, lora_alpha=32, lora_dropout=0.05,
          target_modules=["query", "value"],
      )
+     model = get_peft_model(model, lora)
      model.print_trainable_parameters()
+
+     # ── preprocess ─────────────────────────────────────────────────────
      def preprocess(examples):
+         return tok(
              examples["question"],
              examples["context"],
              truncation="only_second",
+             max_length=MAX_LEN,
+             stride=DOC_STRIDE,
              return_overflowing_tokens=True,
              return_offsets_mapping=True,
              padding="max_length",
+         ) | { "example_id": examples["id"] }

+     train_feats = train_raw.map(
+         preprocess, batched=True, remove_columns=train_raw.column_names,
+         num_proc=4, desc="tokenise-train"
+     )
+     val_feats = val_raw.map(
+         preprocess, batched=True, remove_columns=val_raw.column_names,
+         num_proc=4, desc="tokenise-val"
+     )
+
+     global raw_val  # for metric fn
+     raw_val = val_raw
+
+     # ── training args ──────────────────────────────────────────────────
      args = TrainingArguments(
+         output_dir="./cuad_lora_out",
+         learning_rate=3e-5,
+         num_train_epochs=4,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=8,
+         gradient_accumulation_steps=4,  # eff. BS 32
+         fp16=False, bf16=True,  # L4 = bf16
+         evaluation_strategy="steps",
+         eval_steps=250,
+         save_steps=500,
          save_total_limit=2,
          weight_decay=0.01,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.1,
+         load_best_model_at_end=True,
+         metric_for_best_model="f1",
          greater_is_better=True,
+         logging_steps=50,
+         report_to="none",
      )

+     trainer = QuestionAnsweringTrainer(
          model=model,
          args=args,
+         train_dataset=train_feats,
+         eval_dataset=val_feats,
          tokenizer=tok,
          data_collator=default_data_collator,
          compute_metrics=compute_metrics,
      )

+     print("🚀 Training…")
+     trainer.train()
+
+     print(" Done. Best F1:", trainer.state.best_metric)
+     trainer.save_model("./cuad_lora_out")
+     tok.save_pretrained("./cuad_lora_out")
+
+     # optional: push
+     if tokn:
+         trainer.push_to_hub(model_repo, private=False)
+         tok.push_to_hub(model_repo, private=False)
+         print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")

  if __name__ == "__main__":
+     main()
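
The new `preprocess` attaches `example_id` straight from `examples["id"]`, but with `return_overflowing_tokens=True` one contract can expand into several feature windows, so the id column is usually repeated per window via `overflow_to_sample_mapping`. A minimal sketch of that standard pattern, reusing the script's `tok`, `MAX_LEN` and `DOC_STRIDE` (the function name `prepare_features` is illustrative, not part of the commit):

def prepare_features(examples):
    # Tokenise question/context pairs into overlapping windows.
    tokenized = tok(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # One example may yield several windows; repeat its id per window
    # so post-processing can group the windows back together.
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    tokenized["example_id"] = [examples["id"][i] for i in sample_map]
    return tokenized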
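
`QuestionAnsweringTrainer` is not exported by the core transformers package; it is defined in the library's question-answering examples (trainer_qa.py) and is normally copied next to the training script. A rough sketch of equivalent evaluation wiring with the stock `Trainer`, reusing `postprocess_qa`, `val_feats`, `val_raw`, `metric`, `model`, `args`, `tok` and `default_data_collator` from the script above (all of this is illustrative, not part of the commit):

from transformers import Trainer

# Plain Trainer without compute_metrics; QA metrics need the raw start/end
# logits post-processed into text spans first, which is done manually below.
eval_trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tok,
    data_collator=default_data_collator,
)

# predict() returns (start_logits, end_logits) for every feature window.
raw_preds = eval_trainer.predict(val_feats).predictions
predictions = postprocess_qa(val_raw, val_feats, raw_preds, tok)
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in val_raw]
print(metric.compute(predictions=predictions, references=references))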