AvocadoMuffin committed
Commit c48cc67 · verified · 1 Parent(s): 22050d0

Update train.py

Files changed (1): train.py +217 -137
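The core of this change is a resumable, chunked preprocessing pipeline: tokenized features are written to disk chunk by chunk with torch.save, already-finished chunks are skipped on restart, and the chunks are finally combined into a single checkpoint. A minimal, self-contained sketch of that pattern (simplified to plain lists and a hypothetical process_one callable; the committed helpers are save_partial_features, load_and_combine_chunks, and preprocess_with_chunking below):

import os
import torch

CHECKPOINT_DIR = "./cuad_lora_checkpoints"
CHUNK_SIZE = 100

def process_chunked(items, name, process_one):
    """Process items chunk by chunk, checkpointing each chunk so a crash can resume."""
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    num_chunks = (len(items) + CHUNK_SIZE - 1) // CHUNK_SIZE
    for idx in range(num_chunks):
        path = f"{CHECKPOINT_DIR}/{name}_chunk_{idx:04d}.pt"
        if os.path.exists(path):   # resume: finished chunks are skipped
            continue
        chunk = [process_one(x) for x in items[idx * CHUNK_SIZE:(idx + 1) * CHUNK_SIZE]]
        torch.save(chunk, path)    # checkpoint after every chunk
    combined = []                  # zero-padded filenames keep numeric order
    for idx in range(num_chunks):
        combined.extend(torch.load(f"{CHECKPOINT_DIR}/{name}_chunk_{idx:04d}.pt", map_location="cpu"))
    return combined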
train.py CHANGED
@@ -1,9 +1,8 @@
 #!/usr/bin/env python
 # train_cuad_lora_improved.py
 """
-CUAD fine-tune with LoRA on an L4 / T4 GPU.
-Improved version with better error handling and recovery mechanisms.
-Expected wall-clock on Nvidia L4: ~25-30 min.
 """

 import os, json, random, gc, time
@@ -20,14 +19,15 @@ from peft import LoraConfig, get_peft_model, TaskType
 import evaluate
 from huggingface_hub import login

-disable_caching()  # avoids giant disk cache on Colab

 # ─────────────────────────────────────────────────────────────── helpers ──

-MAX_LEN = 384  # window
 DOC_STRIDE = 128
 SEED = 42
 CHECKPOINT_DIR = "./cuad_lora_checkpoints"

 def set_seed(seed):
     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
@@ -43,9 +43,42 @@ def load_checkpoint(checkpoint_path):
     """Load preprocessing checkpoint from disk"""
     if os.path.exists(checkpoint_path):
         print(f"📂 Loading checkpoint: {checkpoint_path}")
-        return torch.load(checkpoint_path)
     return None

 def balance_has_answer(dataset, ratio=2.0):
     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
     has, no = [], []
@@ -105,112 +138,163 @@ def compute_metrics(eval_pred):

 # ───────────────────────────────────────────────────────────────── main ──

-def preprocess_with_retry(dataset, dataset_name, max_retries=3):
-    """Preprocess dataset with retry logic and checkpointing"""
-
-    def preprocess(examples):
-        # Tokenize all at once
-        tokenized = tok(
-            examples["question"],
-            examples["context"],
-            truncation="only_second",
-            max_length=MAX_LEN,
-            stride=DOC_STRIDE,
-            return_overflowing_tokens=True,
-            return_offsets_mapping=True,
-            padding="max_length",
-        )
-
-        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
-        offset_mapping = tokenized["offset_mapping"]
-
-        # Vectorized processing
-        start_positions = []
-        end_positions = []
-        example_ids = []
-
-        for i in range(len(tokenized["input_ids"])):
-            sample_idx = sample_mapping[i]
-            answers = examples["answers"][sample_idx]
-            offsets = offset_mapping[i]
-
-            # Find CLS token position (always 0 for RoBERTa)
-            cls_index = 0
-
-            example_ids.append(examples["id"][sample_idx])
-
-            # No answer case
-            if not answers["text"] or not answers["text"][0]:
-                start_positions.append(cls_index)
-                end_positions.append(cls_index)
-                continue
-
-            # Get answer span
-            answer_start = answers["answer_start"][0]
-            answer_text = answers["text"][0]
-            answer_end = answer_start + len(answer_text)
-
-            # Find token positions
-            start_token = end_token = cls_index
-
-            for tok_idx, (start_char, end_char) in enumerate(offsets):
-                if start_char <= answer_start < end_char:
-                    start_token = tok_idx
-                if start_char < answer_end <= end_char:
-                    end_token = tok_idx
-                    break
-
-            # Ensure valid span
-            if start_token <= end_token and start_token > 0:
-                start_positions.append(start_token)
-                end_positions.append(end_token)
-            else:
-                start_positions.append(cls_index)
-                end_positions.append(cls_index)
-
-        tokenized["start_positions"] = start_positions
-        tokenized["end_positions"] = end_positions
-        tokenized["example_id"] = example_ids
-        return tokenized
-
-    checkpoint_path = f"{CHECKPOINT_DIR}/{dataset_name}_features.pt"
-
-    # Try to load from checkpoint first
-    features = load_checkpoint(checkpoint_path)
-    if features is not None:
-        print(f"✅ Loaded {dataset_name} features from checkpoint")
-        return Dataset.from_dict(features)
-
-    # Process with retries
-    for attempt in range(max_retries):
-        try:
-            print(f"🔄 Preprocessing {dataset_name} data (attempt {attempt + 1}/{max_retries})...")
-
-            # Use smaller batch sizes and reduce num_proc for stability
-            features = dataset.map(
-                preprocess,
-                batched=True,
-                remove_columns=dataset.column_names,
-                num_proc=2,  # Reduced from 4 for stability
-                desc=f"tokenise-{dataset_name}",
-                load_from_cache_file=False,
-                batch_size=50,  # Reduced from 100 for stability
-                writer_batch_size=50  # Add writer batch size limit
-            )
-
-            # Save checkpoint after successful processing
-            save_checkpoint(features.to_dict(), checkpoint_path)
-            return features
-
-        except Exception as e:
-            print(f"❌ Preprocessing failed on attempt {attempt + 1}: {e}")
-            if attempt < max_retries - 1:
-                print("⏳ Waiting 10 seconds before retry...")
-                time.sleep(10)
-                gc.collect()  # Clean up memory
-            else:
-                print("💥 All preprocessing attempts failed!")
-                raise e

 def main():
     global val_raw, val_feats, tok
@@ -229,8 +313,7 @@ def main():
         print("🔑 HuggingFace Hub login OK")
     except Exception as e:
         print(f"⚠️ Hub login failed: {e}")
-        print("📝 Training will continue but won't push to Hub")
-        tokn = None  # Disable pushing

     print("📚 Loading CUAD…")
     dataset_checkpoint = f"{CHECKPOINT_DIR}/cuad_dataset.pt"
@@ -241,17 +324,15 @@ def main():
         cuad = Dataset.from_dict(dataset_data)
         print(f"✅ Loaded dataset from checkpoint: {len(cuad)} examples")
     else:
-        # Load and process dataset
         try:
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
             print(f"✅ Loaded {len(cuad)} examples")
         except Exception as e:
             print(f"❌ Dataset loading failed: {e}")
-            print("🔄 Retrying with cache disabled...")
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")

     cuad = cuad.shuffle(seed=SEED)
-    cuad = balance_has_answer(cuad, ratio=2.0)  # ≈18k rows
     print(f"📊 Balanced dataset: {len(cuad)} examples")

     # Save dataset checkpoint
@@ -263,7 +344,7 @@ def main():

     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
     base_ckpt = "deepset/roberta-base-squad2"
-    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

     # LoRA
@@ -275,24 +356,30 @@ def main():
     model = get_peft_model(model, lora)
     model.print_trainable_parameters()

-    # ── preprocess with retry logic ─────────────────────────────────────────
-    train_feats = preprocess_with_retry(train_raw, "train")
-    # Remove offset_mapping from training data (not needed during training)
     if "offset_mapping" in train_feats.column_names:
         train_feats = train_feats.remove_columns(["offset_mapping"])

-    val_feats = preprocess_with_retry(val_raw, "val")
-    # Keep offset_mapping for validation (needed for postprocessing)

     # ── training args ──────────────────────────────────────────────────
     args = TrainingArguments(
         output_dir="./cuad_lora_out",
         learning_rate=3e-5,
         num_train_epochs=4,
-        per_device_train_batch_size=8,
-        per_device_eval_batch_size=8,
-        gradient_accumulation_steps=4,  # eff. BS 32
-        fp16=False, bf16=True,  # L4 = bf16
         eval_strategy="steps",
         eval_steps=250,
         save_steps=500,
@@ -305,11 +392,9 @@ def main():
         greater_is_better=True,
         logging_steps=50,
         report_to="none",
-        # Add resume from checkpoint capability
         resume_from_checkpoint=True,
-        # Add dataloader settings for stability
-        dataloader_num_workers=0,  # Disable multiprocessing for data loading
-        dataloader_pin_memory=False,  # Reduce memory pressure
     )

     trainer = Trainer(
@@ -328,46 +413,41 @@ def main():
         print("✅ Training completed successfully!")
     except Exception as e:
         print(f"❌ Training failed: {e}")
-        print("💾 Attempting to save current state...")
         try:
             trainer.save_model("./cuad_lora_out_partial")
             tok.save_pretrained("./cuad_lora_out_partial")
-            print("💾 Partial model saved to ./cuad_lora_out_partial")
         except:
             print("❌ Could not save partial model")
         raise e

-    print("✅ Done. Best F1:", trainer.state.best_metric)
     trainer.save_model("./cuad_lora_out")
     tok.save_pretrained("./cuad_lora_out")

-    # optional: push (with retry logic)
     if tokn:
-        max_push_retries = 3
-        for push_attempt in range(max_push_retries):
             try:
-                print(f"⬆️ Pushing to Hub (attempt {push_attempt + 1}/{max_push_retries})...")
                 trainer.push_to_hub(model_repo, private=False)
                 tok.push_to_hub(model_repo, private=False)
                 print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
                 break
             except Exception as e:
-                print(f"⚠️ Hub push failed on attempt {push_attempt + 1}: {e}")
-                if push_attempt < max_push_retries - 1:
-                    print("⏳ Waiting 30 seconds before retry...")
                     time.sleep(30)
                 else:
-                    print("💾 Model saved locally in ./cuad_lora_out (push failed)")
-    else:
-        print("💾 Model saved locally in ./cuad_lora_out (no HF token for push)")

-    # Clean up checkpoints after successful completion
     try:
         import shutil
         shutil.rmtree(CHECKPOINT_DIR)
         print("🧹 Cleaned up temporary checkpoints")
     except:
-        print("⚠️ Could not clean up temporary checkpoints")

 if __name__ == "__main__":
     main()
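Both the removed and the added preprocessing share the same character-to-token alignment step: walk offset_mapping until one token's character span contains answer_start and another contains answer_end, falling back to the CLS index when no span matches. A toy illustration with hypothetical offsets:

offsets = [(0, 0), (0, 4), (5, 9), (10, 18)]   # (start_char, end_char) per token; (0, 0) is the CLS slot
answer_start, answer_end = 5, 18               # answer covers characters 5..17
start_token = end_token = 0                    # fall back to CLS (index 0) if no match
for tok_idx, (start_char, end_char) in enumerate(offsets):
    if start_char <= answer_start < end_char:
        start_token = tok_idx
    if start_char < answer_end <= end_char:
        end_token = tok_idx
        break
print(start_token, end_token)                  # -> 2 3

The updated version of train.py, with added lines marked +, follows.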
 
 #!/usr/bin/env python
 # train_cuad_lora_improved.py
 """
+CUAD fine-tune with LoRA on an L40S GPU in HuggingFace Spaces.
+Improved version with better error handling and chunked processing.
 """

 import os, json, random, gc, time

 import evaluate
 from huggingface_hub import login

+disable_caching()

 # ─────────────────────────────────────────────────────────────── helpers ──

+MAX_LEN = 384
 DOC_STRIDE = 128
 SEED = 42
 CHECKPOINT_DIR = "./cuad_lora_checkpoints"
+CHUNK_SIZE = 100  # Process in smaller chunks to avoid timeouts

 def set_seed(seed):
     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

     """Load preprocessing checkpoint from disk"""
     if os.path.exists(checkpoint_path):
         print(f"📂 Loading checkpoint: {checkpoint_path}")
+        return torch.load(checkpoint_path, map_location='cpu')
     return None

+def save_partial_features(features_dict, chunk_idx, dataset_name):
+    """Save partial features for a chunk"""
+    partial_path = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
+    save_checkpoint(features_dict, partial_path)
+    return partial_path
+
+def load_and_combine_chunks(dataset_name):
+    """Load all chunk files and combine them"""
+    chunk_files = []
+    if os.path.exists(CHECKPOINT_DIR):
+        for f in os.listdir(CHECKPOINT_DIR):
+            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
+                chunk_files.append(os.path.join(CHECKPOINT_DIR, f))
+
+    if not chunk_files:
+        return None
+
+    chunk_files.sort()
+    print(f"📂 Found {len(chunk_files)} chunks for {dataset_name}")
+
+    # Combine all chunks
+    combined = None
+    for chunk_file in chunk_files:
+        chunk_data = torch.load(chunk_file, map_location='cpu')
+        if combined is None:
+            combined = chunk_data
+        else:
+            for key in chunk_data:
+                combined[key].extend(chunk_data[key])
+
+    print(f"✅ Combined {len(combined['input_ids'])} features from chunks")
+    return combined
+
 def balance_has_answer(dataset, ratio=2.0):
     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
     has, no = [], []

 # ───────────────────────────────────────────────────────────────── main ──

+def preprocess_single_example(example, tokenizer):
+    """Process a single example to avoid batch processing issues"""
+    # Tokenize
+    tokenized = tokenizer(
+        example["question"],
+        example["context"],
+        truncation="only_second",
+        max_length=MAX_LEN,
+        stride=DOC_STRIDE,
+        return_overflowing_tokens=True,
+        return_offsets_mapping=True,
+        padding="max_length",
+    )
+
+    results = {
+        "input_ids": [],
+        "attention_mask": [],
+        "start_positions": [],
+        "end_positions": [],
+        "example_id": [],
+        "offset_mapping": []
+    }
+
+    for i in range(len(tokenized["input_ids"])):
+        results["input_ids"].append(tokenized["input_ids"][i])
+        results["attention_mask"].append(tokenized["attention_mask"][i])
+        results["offset_mapping"].append(tokenized["offset_mapping"][i])
+        results["example_id"].append(example["id"])
+
+        # Handle answer positions
+        answers = example["answers"]
+        offsets = tokenized["offset_mapping"][i]
+        cls_index = 0
+
+        if not answers["text"] or not answers["text"][0]:
+            results["start_positions"].append(cls_index)
+            results["end_positions"].append(cls_index)
+            continue
+
+        answer_start = answers["answer_start"][0]
+        answer_text = answers["text"][0]
+        answer_end = answer_start + len(answer_text)
+
+        start_token = end_token = cls_index
+
+        for tok_idx, (start_char, end_char) in enumerate(offsets):
+            if start_char <= answer_start < end_char:
+                start_token = tok_idx
+            if start_char < answer_end <= end_char:
+                end_token = tok_idx
+                break
+
+        if start_token <= end_token and start_token > 0:
+            results["start_positions"].append(start_token)
+            results["end_positions"].append(end_token)
+        else:
+            results["start_positions"].append(cls_index)
+            results["end_positions"].append(cls_index)
+
+    return results

+def preprocess_with_chunking(dataset, dataset_name, tokenizer):
+    """Process dataset in chunks to avoid timeouts"""
+
+    # Check if final result already exists
+    final_checkpoint = f"{CHECKPOINT_DIR}/{dataset_name}_features.pt"
+    final_features = load_checkpoint(final_checkpoint)
+    if final_features is not None:
+        print(f"✅ Loaded {dataset_name} features from final checkpoint")
+        return Dataset.from_dict(final_features)
+
+    # Check if we can resume from chunks
+    combined_features = load_and_combine_chunks(dataset_name)
+    if combined_features is not None:
+        # Save as final checkpoint
+        save_checkpoint(combined_features, final_checkpoint)
+        return Dataset.from_dict(combined_features)
+
+    # Process in chunks
+    print(f"🔄 Processing {dataset_name} dataset in chunks of {CHUNK_SIZE}...")
+
+    total_samples = len(dataset)
+    num_chunks = (total_samples + CHUNK_SIZE - 1) // CHUNK_SIZE
+
+    for chunk_idx in range(num_chunks):
+        chunk_file = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
+
+        # Skip if chunk already processed
+        if os.path.exists(chunk_file):
+            print(f"⏭️ Chunk {chunk_idx + 1}/{num_chunks} already exists, skipping...")
+            continue
+
+        start_idx = chunk_idx * CHUNK_SIZE
+        end_idx = min(start_idx + CHUNK_SIZE, total_samples)
+
+        print(f"🔄 Processing chunk {chunk_idx + 1}/{num_chunks} (samples {start_idx}-{end_idx-1})...")
+
+        chunk_results = {
+            "input_ids": [],
+            "attention_mask": [],
+            "start_positions": [],
+            "end_positions": [],
+            "example_id": [],
+            "offset_mapping": []
+        }
+
+        # Process each example in the chunk individually
+        for i in range(start_idx, end_idx):
+            if i % 10 == 0:  # Progress indicator
+                print(f"  Processing sample {i}/{total_samples}")
+
+            try:
+                example = dataset[i]
+                result = preprocess_single_example(example, tokenizer)
+
+                # Add to chunk results
+                for key in chunk_results:
+                    chunk_results[key].extend(result[key])
+
+            except Exception as e:
+                print(f"⚠️ Error processing sample {i}: {e}")
+                continue
+
+        # Save chunk
+        save_partial_features(chunk_results, chunk_idx, dataset_name)
+
+        # Clean up memory
+        del chunk_results
+        gc.collect()
+
+        print(f"✅ Chunk {chunk_idx + 1}/{num_chunks} completed and saved")
+
+    # Combine all chunks
+    print("🔄 Combining all chunks...")
+    combined_features = load_and_combine_chunks(dataset_name)
+
+    if combined_features is None:
+        raise RuntimeError("Failed to load and combine chunks!")
+
+    # Save final result
+    save_checkpoint(combined_features, final_checkpoint)
+
+    # Clean up chunk files
+    cleanup_chunk_files(dataset_name)
+
+    return Dataset.from_dict(combined_features)
+
+def cleanup_chunk_files(dataset_name):
+    """Remove chunk files after successful combination"""
+    if os.path.exists(CHECKPOINT_DIR):
+        for f in os.listdir(CHECKPOINT_DIR):
+            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
+                try:
+                    os.remove(os.path.join(CHECKPOINT_DIR, f))
+                except:
+                    pass
+    print(f"🧹 Cleaned up chunk files for {dataset_name}")

 def main():
     global val_raw, val_feats, tok

         print("🔑 HuggingFace Hub login OK")
     except Exception as e:
         print(f"⚠️ Hub login failed: {e}")
+        tokn = None

     print("📚 Loading CUAD…")
     dataset_checkpoint = f"{CHECKPOINT_DIR}/cuad_dataset.pt"

         cuad = Dataset.from_dict(dataset_data)
         print(f"✅ Loaded dataset from checkpoint: {len(cuad)} examples")
     else:
         try:
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
             print(f"✅ Loaded {len(cuad)} examples")
         except Exception as e:
             print(f"❌ Dataset loading failed: {e}")
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")

     cuad = cuad.shuffle(seed=SEED)
+    cuad = balance_has_answer(cuad, ratio=2.0)
     print(f"📊 Balanced dataset: {len(cuad)} examples")

     # Save dataset checkpoint

     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
     base_ckpt = "deepset/roberta-base-squad2"
+    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

     # LoRA

     model = get_peft_model(model, lora)
     model.print_trainable_parameters()

+    # ── preprocess with chunking ─────────────────────────────────────────
+    print("🔄 Starting preprocessing...")
+
+    train_feats = preprocess_with_chunking(train_raw, "train", tok)
+    # Remove offset_mapping from training data
     if "offset_mapping" in train_feats.column_names:
         train_feats = train_feats.remove_columns(["offset_mapping"])

+    val_feats = preprocess_with_chunking(val_raw, "val", tok)
+    # Keep offset_mapping for validation
+
+    print("✅ Preprocessing completed!")
+    print(f"  Training features: {len(train_feats)}")
+    print(f"  Validation features: {len(val_feats)}")

     # ── training args ──────────────────────────────────────────────────
     args = TrainingArguments(
         output_dir="./cuad_lora_out",
         learning_rate=3e-5,
         num_train_epochs=4,
+        per_device_train_batch_size=16,  # Increased for L40S
+        per_device_eval_batch_size=16,
+        gradient_accumulation_steps=2,  # Reduced since batch size increased
+        fp16=False, bf16=True,
         eval_strategy="steps",
         eval_steps=250,
         save_steps=500,

         greater_is_better=True,
         logging_steps=50,
         report_to="none",
         resume_from_checkpoint=True,
+        dataloader_num_workers=2,  # L40S can handle more workers
+        dataloader_pin_memory=True,
     )

     trainer = Trainer(

         print("✅ Training completed successfully!")
     except Exception as e:
         print(f"❌ Training failed: {e}")
         try:
             trainer.save_model("./cuad_lora_out_partial")
             tok.save_pretrained("./cuad_lora_out_partial")
+            print("💾 Partial model saved")
         except:
             print("❌ Could not save partial model")
         raise e

+    print("✅ Done. Best F1:", trainer.state.best_metric)
     trainer.save_model("./cuad_lora_out")
     tok.save_pretrained("./cuad_lora_out")

+    # Push to hub with retry logic
     if tokn:
+        for attempt in range(3):
             try:
+                print(f"⬆️ Pushing to Hub (attempt {attempt + 1}/3)...")
                 trainer.push_to_hub(model_repo, private=False)
                 tok.push_to_hub(model_repo, private=False)
                 print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
                 break
             except Exception as e:
+                print(f"⚠️ Hub push failed: {e}")
+                if attempt < 2:
                     time.sleep(30)
                 else:
+                    print("💾 Model saved locally (push failed)")

+    # Clean up checkpoints
     try:
         import shutil
         shutil.rmtree(CHECKPOINT_DIR)
         print("🧹 Cleaned up temporary checkpoints")
     except:
+        print("⚠️ Could not clean up checkpoints")

 if __name__ == "__main__":
     main()
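Two notes on pieces the diff does not show. First, save_checkpoint is called throughout but its definition sits in an unchanged hunk; given that load_checkpoint wraps torch.load, it is presumably a thin torch.save wrapper along these lines (an assumption, not the committed code):

import os
import torch

def save_checkpoint(obj, checkpoint_path):
    # Hypothetical reconstruction: persist a checkpoint, creating its directory first.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(obj, checkpoint_path)
    print(f"💾 Saved checkpoint: {checkpoint_path}")

Second, the batch-size change keeps the effective train batch size constant: 8 × 4 = 32 before, 16 × 2 = 32 after; the larger per-device batch simply shifts more of that work onto the L40S's memory.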