AvocadoMuffin committed · verified
Commit 29ee80b · 1 parent: bad80b7

Create train1.py

Files changed (1): train1.py (+379, -0)
train1.py ADDED
@@ -0,0 +1,379 @@
+ import torch, gc, os, numpy as np, evaluate, json
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForQuestionAnswering,
+     TrainingArguments,
+     Trainer,
+     default_data_collator,
+ )
+ from peft import LoraConfig, get_peft_model, TaskType
+ from huggingface_hub import login
+ import sys
+
+ def main():
+     # Get model name from environment
+     model_name = os.environ.get('MODEL_NAME', 'roberta-cuad-qa')
+
+     # Login to HF Hub
+     hf_token = os.environ.get('roberta_token')
+     if hf_token:
+         login(token=hf_token)
+         print("✅ Logged into Hugging Face Hub")
+     else:
+         print("⚠️ No roberta_token found - model won't be pushed to Hub")
+
+     # Setup
+     torch.cuda.empty_cache()
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"🔧 Using device: {device}")
+     if torch.cuda.is_available():
+         print(f"🎯 GPU: {torch.cuda.get_device_name()}")
+         print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+
+     # Load and prepare data - REDUCED SIZE FOR FASTER TRAINING
+     print("📚 Loading CUAD dataset...")
+     raw = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
+
+     # Use 5000 samples for good model quality - expect ~1 hour training
+     N = 5000
+     raw = raw.shuffle(seed=42).select(range(min(N, len(raw))))
+     ds = raw.train_test_split(test_size=0.1, seed=42)
+     train_ds, val_ds = ds["train"], ds["test"]
+     print(f"✅ Data loaded - Train: {len(train_ds)}, Val: {len(val_ds)}")
+
+     # Store original validation data for metrics
+     print("📊 Preparing metrics data...")
+     original_val_data = []
+     val_sample_mapping = []  # Track which tokenized sample maps to which original
+
+     for i, ex in enumerate(val_ds):
+         original_val_data.append(ex["answers"])
+
+     # Load model and tokenizer
+     print("🤖 Loading RoBERTa model...")
+     base_model = "roberta-base"
+     tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
+     model = AutoModelForQuestionAnswering.from_pretrained(base_model)
+
+     # Add LoRA
+     print("🔧 Adding LoRA adapters...")
+     lora_cfg = LoraConfig(
+         task_type=TaskType.QUESTION_ANS,
+         target_modules=["query", "value"],
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+     )
+     model = get_peft_model(model, lora_cfg)
+     model.print_trainable_parameters()
+     model.to(device)
+
+     # Tokenization settings: `stride` here is the token overlap between consecutive
+     # chunks of a long context, so a value close to max_length produces heavy overlap
+     # (the expansion check further below guards against the resulting chunk blow-up)
+     max_len, doc_stride = 512, 400
+
+     def preprocess(examples):
+         tok_batch = tok(
+             examples["question"],
+             examples["context"],
+             truncation="only_second",
+             max_length=max_len,
+             stride=doc_stride,
+             return_overflowing_tokens=True,
+             return_offsets_mapping=True,
+             padding="max_length",
+         )
+
+         sample_map = tok_batch.pop("overflow_to_sample_mapping")
+         offset_map = tok_batch.pop("offset_mapping")
+
+         start_pos, end_pos = [], []
+         for i, offsets in enumerate(offset_map):
+             cls_idx = tok_batch["input_ids"][i].index(tok.cls_token_id)
+             sample_idx = sample_map[i]
+             answer = examples["answers"][sample_idx]
+
+             # No answer at all: point both positions at the CLS token
+             if len(answer["answer_start"]) == 0:
+                 start_pos.append(cls_idx)
+                 end_pos.append(cls_idx)
+                 continue
+
+             s_char = answer["answer_start"][0]
+             e_char = s_char + len(answer["text"][0])
+             seq_ids = tok_batch.sequence_ids(i)
+             c0, c1 = seq_ids.index(1), len(seq_ids) - 1 - seq_ids[::-1].index(1)
+
+             # Answer not contained in this chunk's context window
+             if not (offsets[c0][0] <= s_char <= offsets[c1][1]):
+                 start_pos.append(cls_idx)
+                 end_pos.append(cls_idx)
+                 continue
+
+             st = c0
+             while st <= c1 and offsets[st][0] <= s_char:
+                 st += 1
+
+             en = c1
+             while en >= c0 and offsets[en][1] >= e_char:
+                 en -= 1
+
+             # Position calculation with bounds checking
+             start_pos.append(max(c0, min(st - 1, c1)))
+             end_pos.append(max(c0, min(en + 1, c1)))
+
+         tok_batch["start_positions"] = start_pos
+         tok_batch["end_positions"] = end_pos
+
+         # Store sample mapping for metrics calculation
+         tok_batch["sample_mapping"] = sample_map
+         return tok_batch
+
+     # Tokenize datasets
+     print("🔄 Tokenizing datasets...")
+     train_tok = train_ds.map(
+         preprocess,
+         batched=True,
+         batch_size=50,  # Smaller batch size for preprocessing
+         remove_columns=train_ds.column_names,
+         desc="Tokenizing train"
+     )
+
+     val_tok = val_ds.map(
+         preprocess,
+         batched=True,
+         batch_size=50,
+         remove_columns=val_ds.column_names,
+         desc="Tokenizing validation"
+     )
+
+     # DEBUG: Print actual dataset sizes after tokenization
+     print("🔍 DEBUG INFO:")
+     print(f"   Original samples: {N}")
+     print(f"   After tokenization - Train: {len(train_tok)}, Val: {len(val_tok)}")
+     print(f"   Expansion factor: {len(train_tok)/len(train_ds):.1f}x")
+
+     # SAFETY CHECK: If expansion is too high, reduce data size automatically
+     expansion_factor = len(train_tok) / len(train_ds)
+     if expansion_factor > 12:  # Slightly more permissive for 5K samples
+         print(f"⚠️ HIGH EXPANSION DETECTED ({expansion_factor:.1f}x)!")
+         print("🔧 Auto-reducing dataset size to prevent excessively slow training...")
+
+         # Allow up to 20k samples for ~1 hour of training
+         target_size = min(20000, len(train_tok))  # Max 20k samples
+         train_indices = list(range(0, len(train_tok), max(1, len(train_tok) // target_size)))[:target_size]
+         val_indices = list(range(0, len(val_tok), max(1, len(val_tok) // (target_size // 10))))[:target_size // 10]
+
+         train_tok = train_tok.select(train_indices)
+         val_tok = val_tok.select(val_indices)
+
+         print(f"✅ Reduced to - Train: {len(train_tok)}, Val: {len(val_tok)}")
+         print("📈 This should complete in ~45-75 minutes")
+
+     # Clean up memory
+     del raw, ds, train_ds, val_ds
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     # Metrics setup
+     metric = evaluate.load("squad")
+
+     def postprocess(preds, dataset):
+         starts, ends = preds
+         answers = []
+         for i in range(len(starts)):
+             a, b = int(np.argmax(starts[i])), int(np.argmax(ends[i]))
+             if a > b:
+                 a, b = b, a
+             text = tok.decode(dataset[i]["input_ids"][a:b+1], skip_special_tokens=True)
+             answers.append(text.strip())
+         return answers
+
+     def compute_metrics(eval_pred):
+         preds, _ = eval_pred
+         starts, ends = preds
+         try:
+             # Group predictions by original sample (handle multiple chunks per sample)
+             sample_predictions = {}
+             for i in range(len(starts)):
+                 # Get which original sample this tokenized example came from
+                 if "sample_mapping" in val_tok.column_names:
+                     orig_idx = val_tok[i]["sample_mapping"]
+                 else:
+                     # Fallback: assume 1:1 mapping (may be inaccurate with chunking)
+                     orig_idx = min(i, len(original_val_data) - 1)
+
+                 # Get best answer span for this chunk
+                 start_idx = int(np.argmax(starts[i]))
+                 end_idx = int(np.argmax(ends[i]))
+                 if start_idx > end_idx:
+                     start_idx, end_idx = end_idx, start_idx
+
+                 # Extract answer text
+                 answer_text = tok.decode(
+                     val_tok[i]["input_ids"][start_idx:end_idx+1],
+                     skip_special_tokens=True
+                 ).strip()
+
+                 # Keep the highest-confidence prediction for this original sample
+                 confidence = float(starts[i][start_idx]) + float(ends[i][end_idx])
+                 if orig_idx not in sample_predictions or confidence > sample_predictions[orig_idx][1]:
+                     sample_predictions[orig_idx] = (answer_text, confidence)
+
+             # Format for the SQuAD metric
+             predictions = []
+             references = []
+             for orig_idx in range(len(original_val_data)):
+                 pred_text = sample_predictions.get(orig_idx, ("", 0))[0]
+                 predictions.append({
+                     "id": str(orig_idx),
+                     "prediction_text": pred_text
+                 })
+                 references.append({
+                     "id": str(orig_idx),
+                     "answers": original_val_data[orig_idx]
+                 })
+
+             result = metric.compute(predictions=predictions, references=references)
+
+             # Add some debugging info
+             print(f"📊 Evaluation: EM={result['exact_match']:.3f}, F1={result['f1']:.3f}")
+             return result
+
+         except Exception as e:
+             print(f"⚠️ Metrics computation failed: {e}")
+             print(f"   Pred shape: {np.array(preds).shape}")
+             print(f"   Val dataset size: {len(val_tok)}")
+             print(f"   Original val size: {len(original_val_data)}")
+             return {"exact_match": 0.0, "f1": 0.0}
+
+     # OPTIMIZED Training arguments
+     output_dir = "./model_output"
+     args = TrainingArguments(
+         output_dir=output_dir,
+         per_device_train_batch_size=8,   # INCREASED from 2
+         per_device_eval_batch_size=8,    # INCREASED from 4
+         gradient_accumulation_steps=2,   # REDUCED from 8
+         num_train_epochs=3,              # Back to 3 epochs for better training
+         learning_rate=5e-4,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.1,
+         bf16=True,                       # CHANGED from fp16 (better for newer GPUs)
+         eval_strategy="steps",
+         eval_steps=100,                  # REDUCED from 250
+         save_steps=200,                  # REDUCED from 500
+         save_total_limit=2,
+         logging_steps=25,                # REDUCED from 50
+         weight_decay=0.01,
+         remove_unused_columns=True,
+         report_to="none",                # Disable third-party experiment trackers
+         push_to_hub=False,
+         dataloader_pin_memory=True,      # CHANGED to True for faster data loading
+         dataloader_num_workers=4,        # ADDED for parallel data loading
+         gradient_checkpointing=False,    # DISABLED to trade memory for speed
+     )
+
+     # Create trainer
+     trainer = Trainer(
+         model=model,
+         args=args,
+         train_dataset=train_tok,
+         eval_dataset=val_tok,
+         tokenizer=tok,
+         data_collator=default_data_collator,
+         compute_metrics=compute_metrics,
+     )
+
+     print("🚀 Starting training...")
+     print(f"📊 Total training samples: {len(train_tok)}")
+     print(f"📊 Total validation samples: {len(val_tok)}")
+     print(f"⚡ Effective batch size: {args.per_device_train_batch_size * args.gradient_accumulation_steps}")
+
+     if torch.cuda.is_available():
+         print(f"💾 GPU memory before training: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
+
+     # Training loop with error handling
+     try:
+         trainer.train()
+         print("✅ Training completed successfully!")
+     except RuntimeError as e:
+         if "CUDA out of memory" in str(e):
+             print("⚠️ GPU OOM - reducing batch size and retrying...")
+             torch.cuda.empty_cache()
+             gc.collect()
+
+             # Reduce batch size and rebuild the trainer
+             args.per_device_train_batch_size = 4
+             args.gradient_accumulation_steps = 4
+             trainer = Trainer(
+                 model=model,
+                 args=args,
+                 train_dataset=train_tok,
+                 eval_dataset=val_tok,
+                 tokenizer=tok,
+                 data_collator=default_data_collator,
+                 compute_metrics=compute_metrics,
+             )
+             trainer.train()
+             print("✅ Training completed with reduced batch size!")
+         else:
+             raise
+
+     # Save model locally first
+     print("💾 Saving model locally...")
+     os.makedirs(output_dir, exist_ok=True)
+     trainer.model.save_pretrained(output_dir)
+     tok.save_pretrained(output_dir)
+
+     # Save training info
+     training_info = {
+         "model_name": model_name,
+         "base_model": base_model,
+         "dataset": "theatticusproject/cuad-qa",
+         "original_samples": N,
+         "training_samples_after_tokenization": len(train_tok),
+         "validation_samples_after_tokenization": len(val_tok),
+         "lora_config": {
+             "r": lora_cfg.r,
+             "lora_alpha": lora_cfg.lora_alpha,
+             "target_modules": list(lora_cfg.target_modules),  # PEFT may store this as a set; convert for JSON
+             "lora_dropout": lora_cfg.lora_dropout,
+         },
+         "training_args": {
+             "batch_size": args.per_device_train_batch_size,
+             "gradient_accumulation_steps": args.gradient_accumulation_steps,
+             "effective_batch_size": args.per_device_train_batch_size * args.gradient_accumulation_steps,
+             "epochs": args.num_train_epochs,
+             "learning_rate": args.learning_rate,
+         }
+     }
+
+     with open(os.path.join(output_dir, "training_info.json"), "w") as f:
+         json.dump(training_info, f, indent=2)
+
+     # Push to Hub if token available
+     if hf_token:
+         try:
+             print(f"⬆️ Pushing model to Hub: {model_name}")
+             trainer.model.push_to_hub(model_name, private=False)
+             tok.push_to_hub(model_name, private=False)
+
+             # Also push training info
+             from huggingface_hub import upload_file
+             upload_file(
+                 path_or_fileobj=os.path.join(output_dir, "training_info.json"),
+                 path_in_repo="training_info.json",
+                 repo_id=model_name,
+                 repo_type="model"
+             )
+
+             print(f"🎉 Model successfully saved to: https://huggingface.co/{model_name}")
+         except Exception as e:
+             print(f"❌ Failed to push to Hub: {e}")
+             print("💾 Model saved locally in ./model_output/")
+     else:
+         print("💾 Model saved locally in ./model_output/ (no HF token for Hub upload)")
+
+     print("🏁 Training pipeline completed!")
+
+ if __name__ == "__main__":
+     main()
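
The script is configured entirely through environment variables: MODEL_NAME sets the Hub repository name (default roberta-cuad-qa) and roberta_token supplies the Hugging Face access token used for login and the final push. A minimal launch sketch, assuming the file above is saved as train1.py on the Python path; the values below are placeholders to replace with your own:

    import os

    # Placeholders - substitute your own repo name and Hugging Face token
    os.environ["MODEL_NAME"] = "roberta-cuad-qa"
    os.environ["roberta_token"] = "hf_xxx"  # omit this line to train without pushing to the Hub

    import train1
    train1.main()

Equivalently, the same variables can be exported in the shell before running python train1.py.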