#!/usr/bin/env python
# train_cuad_lora_efficient.py - FIXED VERSION
"""
CUAD fine-tune with LoRA - fixed for realistic training times.
"""

import os, json, random, gc, time
from collections import defaultdict
from pathlib import Path

import torch, numpy as np
from datasets import load_dataset, Dataset, disable_caching
from transformers import (
    AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments,
    default_data_collator, Trainer
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from huggingface_hub import login

disable_caching()

# Set tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ─────────────────────────────────────────────────────────────── config ──
MAX_LEN = 512       # Slightly longer context
DOC_STRIDE = 256    # Larger stride = fewer chunks = faster training
SEED = 42
BATCH_SIZE = 1000   # Process in larger, more efficient batches

# Back to a reasonable subset size - a 5k-example run already trained fine
USE_SUBSET = True
SUBSET_SIZE = 7000  # Good middle ground - a bit more than the 5k run


def set_seed(seed):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def balance_has_answer(dataset, ratio=2.0, max_samples=None):
    """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
    has, no = [], []
    for ex in dataset:
        (has if ex["answers"]["text"] else no).append(ex)

    print(f"📊 Original: {len(has)} has-answer, {len(no)} no-answer")

    # FIXED: apply max_samples FIRST, then balance
    if max_samples:
        total_available = len(has) + len(no)
        if total_available > max_samples:
            # Sample proportionally from the original distribution
            has_ratio = len(has) / total_available
            target_has = int(max_samples * has_ratio)
            target_no = max_samples - target_has
            has = random.sample(has, min(target_has, len(has)))
            no = random.sample(no, min(target_no, len(no)))
            print(f"📉 Pre-balance subset: {len(has)} has-answer, {len(no)} no-answer")

    # Now balance within the subset
    k = int(len(has) * ratio)
    if len(no) > k:
        no = random.sample(no, k)

    balanced = has + no
    random.shuffle(balanced)  # Shuffle the final dataset

    n_has = len([x for x in balanced if x["answers"]["text"]])
    n_no = len(balanced) - n_has
    print(f"📊 Final balanced: {n_has} has-answer, {n_no} no-answer")
    print(f"📊 Total examples: {len(balanced)}")

    return Dataset.from_list(balanced)


# ────────────────────────────────────────────────────────────── postproc ──
metric = evaluate.load("squad")


def postprocess_qa(examples, features, raw_predictions, tokenizer):
    """HF-style span extraction; returns a list of SQuAD-format prediction dicts."""
    all_start, all_end = raw_predictions
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = defaultdict(list)
    for i, feat_id in enumerate(features["example_id"]):
        features_per_example[example_id_to_index[feat_id]].append(i)

    predictions = []
    for example_idx, example in enumerate(examples):
        best_score = -1e9
        best_span = ""
        context = example["context"]

        for feat_idx in features_per_example[example_idx]:
            start_logit = all_start[feat_idx]
            end_logit = all_end[feat_idx]
            offset = features["offset_mapping"][feat_idx]

            start_idx = int(np.argmax(start_logit))
            end_idx = int(np.argmax(end_logit))

            if start_idx <= end_idx < len(offset):
                start_char, _ = offset[start_idx]
                _, end_char = offset[end_idx]
                span = context[start_char:end_char].strip()
                score = start_logit[start_idx] + end_logit[end_idx]
                if score > best_score and span:
                    best_score, best_span = score, span

        predictions.append({"id": example["id"], "prediction_text": best_span})

    return predictions
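
# Illustrative usage (not wired into main below): `postprocess_qa` pairs with the
# SQuAD metric loaded above. Assuming the `trainer`, `val_raw`, `val_feats`, and
# `tok` objects built in main(), a sketch of an eval pass could look like:
#
#   raw = trainer.predict(val_feats)                                   # (start_logits, end_logits)
#   preds = postprocess_qa(val_raw, val_feats, raw.predictions, tok)
#   refs = [{"id": ex["id"], "answers": ex["answers"]} for ex in val_raw]
#   print(metric.compute(predictions=preds, references=refs))          # exact match / F1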
# ───────────────────────────────────────────────────────── preprocessing ──
def preprocess_training_batch(examples, tokenizer):
    """Training preprocessing - NO offset_mapping kept in the output."""
    questions = examples["question"]
    contexts = examples["context"]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        cls_index = 0
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]

        # No annotated answer -> point both positions at the CLS token
        if not answers["text"] or not answers["text"][0]:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        answer_start_char = answers["answer_start"][0]
        answer_text = answers["text"][0]
        answer_end_char = answer_start_char + len(answer_text)

        # Only match against context tokens (sequence id 1); question tokens carry
        # offsets into the question string and could otherwise collide by accident.
        sequence_ids = tokenized_examples.sequence_ids(i)

        token_start_index = cls_index
        token_end_index = cls_index
        for token_index, (start_char, end_char) in enumerate(offsets):
            if sequence_ids[token_index] != 1:
                continue
            if start_char <= answer_start_char < end_char:
                token_start_index = token_index
            if start_char < answer_end_char <= end_char:
                token_end_index = token_index
                break

        if token_start_index <= token_end_index and token_start_index > 0:
            start_positions.append(token_start_index)
            end_positions.append(token_end_index)
        else:
            # Answer is not inside this chunk -> label the chunk as no-answer
            start_positions.append(cls_index)
            end_positions.append(cls_index)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples


def preprocess_validation_batch(examples, tokenizer):
    """Validation preprocessing - INCLUDES offset_mapping and example_id."""
    questions = examples["question"]
    contexts = examples["context"]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    tokenized_examples["example_id"] = [
        examples["id"][sample_mapping[i]]
        for i in range(len(tokenized_examples["input_ids"]))
    ]
    return tokenized_examples


def preprocess_dataset_streaming(dataset, tokenizer, desc="Processing", is_training=True):
    """Process the dataset in batches via `datasets.Dataset.map` with batching."""
    print(f"🔄 {desc} dataset with batch processing...")

    preprocess_fn = preprocess_training_batch if is_training else preprocess_validation_batch

    processed = dataset.map(
        lambda examples: preprocess_fn(examples, tokenizer),
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=dataset.column_names,
        desc=desc,
        num_proc=1,
    )

    print(f"✅ {desc} completed: {len(processed)} features")
    return processed
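
# Rough chunking math (illustrative): with return_overflowing_tokens, each new window
# re-uses the last DOC_STRIDE context tokens, so consecutive windows advance by the
# context capacity (MAX_LEN minus question and special tokens) minus DOC_STRIDE.
# A ~1000-token contract chunk at MAX_LEN=512 / DOC_STRIDE=256 therefore yields on
# the order of 3-4 features, which is where main()'s ~2.5 features-per-example
# estimate comes from.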
# ───────────────────────────────────────────────────────────────── main ──
def main():
    set_seed(SEED)

    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v4")

    if (tokn := os.getenv("roberta_token")):
        try:
            login(tokn)
            print("🔑 HuggingFace Hub login OK")
        except Exception as e:
            print(f"⚠️ Hub login failed: {e}")
            tokn = None

    print("📚 Loading CUAD…")
    try:
        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
        print(f"✅ Loaded {len(cuad)} examples")
    except Exception as e:
        print(f"❌ Dataset loading failed: {e}")
        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True,
                            download_mode="force_redownload")

    cuad = cuad.shuffle(seed=SEED)

    # FIXED: apply the subset reduction more aggressively
    subset_size = SUBSET_SIZE if USE_SUBSET else None
    cuad = balance_has_answer(cuad, ratio=1.5, max_samples=subset_size)  # Reduced ratio too
    print(f"📊 Final dataset size: {len(cuad)} examples")

    # Estimate features after preprocessing
    avg_features_per_example = 2.5  # Conservative estimate given the stride
    estimated_features = len(cuad) * avg_features_per_example
    print(f"📊 Estimated training features: ~{int(estimated_features)}")

    ds = cuad.train_test_split(test_size=0.1, seed=SEED)
    train_raw, val_raw = ds["train"], ds["test"]

    # ── tokeniser & model ──
    base_ckpt = "deepset/roberta-base-squad2"
    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
    model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

    # FIXED: lighter LoRA config for faster training
    lora = LoraConfig(
        task_type=TaskType.QUESTION_ANS,
        r=16,                               # Reduced from 32
        lora_alpha=32,                      # Reduced from 64
        lora_dropout=0.1,
        target_modules=["query", "value"],  # Fewer modules
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()

    # ── preprocessing ─────────────────────────────────────────
    print("🔄 Starting preprocessing...")
    train_feats = preprocess_dataset_streaming(train_raw, tok, "Training", is_training=True)
    val_feats = preprocess_dataset_streaming(val_raw, tok, "Validation", is_training=False)
    print("✅ Preprocessing completed!")
    print(f"   Training features: {len(train_feats)}")
    print(f"   Validation features: {len(val_feats)}")

    # ── training args - FIXED for a reasonable training time ──
    batch_size = 16                 # Good balance
    gradient_accumulation_steps = 2
    effective_batch_size = batch_size * gradient_accumulation_steps
    num_epochs = 3                  # Keep it reasonable

    steps_per_epoch = len(train_feats) // effective_batch_size
    total_steps = steps_per_epoch * num_epochs
    eval_steps = max(25, steps_per_epoch // 8)  # More frequent eval
    save_steps = eval_steps * 3

    print("📊 Training configuration:")
    print(f"   Effective batch size: {effective_batch_size}")
    print(f"   Steps per epoch: {steps_per_epoch}")
    print(f"   Total steps: {total_steps}")
    print(f"   Estimated time: ~{total_steps / 2.4 / 60:.1f} minutes")  # assumes ~2.4 steps/s
    print(f"   Eval every: {eval_steps} steps")

    args = TrainingArguments(
        output_dir="./cuad_lora_out",
        learning_rate=3e-5,         # Slightly lower LR
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=gradient_accumulation_steps,
        fp16=False, bf16=True,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        save_total_limit=2,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        load_best_model_at_end=False,
        logging_steps=10,           # More frequent logging
        report_to="none",
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        remove_unused_columns=True,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_feats,
        eval_dataset=val_feats,
        tokenizer=tok,
        data_collator=default_data_collator,
        compute_metrics=None,
    )

    print("🚀 Training…")
    try:
        trainer.train()
        print("✅ Training completed successfully!")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        try:
            trainer.save_model("./cuad_lora_out_partial")
            tok.save_pretrained("./cuad_lora_out_partial")
            print("💾 Partial model saved")
        except Exception:
            print("❌ Could not save partial model")
        raise
    print("✅ Done. Best eval_loss:", trainer.state.best_metric)
    trainer.save_model("./cuad_lora_out")
    tok.save_pretrained("./cuad_lora_out")

    # Push to hub
    if tokn:
        for attempt in range(3):
            try:
                print(f"⬆️ Pushing to Hub (attempt {attempt + 1}/3)...")
                model.push_to_hub(model_repo, private=False)  # push the LoRA adapter to the named repo
                tok.push_to_hub(model_repo, private=False)
                print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
                break
            except Exception as e:
                print(f"⚠️ Hub push failed: {e}")
                if attempt < 2:
                    time.sleep(30)
                else:
                    print("💾 Model saved locally (push failed)")


if __name__ == "__main__":
    main()
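
# Loading the result for inference (illustrative sketch; the repo id and local path
# are the ones used above, the rest is standard transformers/peft API):
#
#   from transformers import AutoTokenizer, AutoModelForQuestionAnswering
#   from peft import PeftModel
#
#   base = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
#   model = PeftModel.from_pretrained(base, "AvocadoMuffin/roberta-cuad-qa-v4")  # or "./cuad_lora_out"
#   tok = AutoTokenizer.from_pretrained("AvocadoMuffin/roberta-cuad-qa-v4")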