#!/usr/bin/env python
# train_cuad_lora_efficient.py - FIXED VERSION
"""
CUAD fine-tune with LoRA - Fixed for realistic training times
"""
import os, json, random, gc, time
from collections import defaultdict
from pathlib import Path
import torch, numpy as np
from datasets import load_dataset, Dataset, disable_caching
from transformers import (
    AutoTokenizer, AutoModelForQuestionAnswering,
    TrainingArguments, default_data_collator, Trainer
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from huggingface_hub import login
disable_caching()
# Set tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ─────────────────────────────────────────────────────────────── config ──
MAX_LEN = 512        # Slightly longer context
DOC_STRIDE = 256     # Larger stride = fewer chunks = faster training
SEED = 42
BATCH_SIZE = 1000    # Process in larger, more efficient batches
# Back to reasonable subset size since you've trained 5k before
USE_SUBSET = True
SUBSET_SIZE = 7000   # Good middle ground - more than your 5k success
def set_seed(seed):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def balance_has_answer(dataset, ratio=2.0, max_samples=None):
    """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
    has, no = [], []
    for ex in dataset:
        (has if ex["answers"]["text"] else no).append(ex)
    print(f"📊 Original: {len(has)} has-answer, {len(no)} no-answer")
    # FIXED: Apply max_samples FIRST, then balance
    if max_samples:
        total_available = len(has) + len(no)
        if total_available > max_samples:
            # Sample proportionally from the original distribution
            has_ratio = len(has) / total_available
            target_has = int(max_samples * has_ratio)
            target_no = max_samples - target_has
            has = random.sample(has, min(target_has, len(has)))
            no = random.sample(no, min(target_no, len(no)))
            print(f"📊 Pre-balance subset: {len(has)} has-answer, {len(no)} no-answer")
    # Now balance within the subset
    k = int(len(has) * ratio)
    if len(no) > k:
        no = random.sample(no, k)
    balanced = has + no
    random.shuffle(balanced)  # Shuffle the final dataset
    print(f"📊 Final balanced: {len([x for x in balanced if x['answers']['text']])} has-answer, "
          f"{len([x for x in balanced if not x['answers']['text']])} no-answer")
    print(f"📊 Total examples: {len(balanced)}")
    return Dataset.from_list(balanced)
# ───────────────────────────────────────────────────────────── postproc ──
metric = evaluate.load("squad")
def postprocess_qa(examples, features, raw_predictions, tokenizer):
    """HF-style best-span extraction; returns a list of SQuAD-format predictions."""
    all_start, all_end = raw_predictions
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = defaultdict(list)
    for i, feat_id in enumerate(features["example_id"]):
        features_per_example[example_id_to_index[feat_id]].append(i)
    predictions = []
    for example_idx, example in enumerate(examples):
        best_score = -1e9
        best_span = ""
        context = example["context"]
        for feat_idx in features_per_example[example_idx]:
            start_logit = all_start[feat_idx]
            end_logit = all_end[feat_idx]
            offset = features["offset_mapping"][feat_idx]
            start_idx = int(np.argmax(start_logit))
            end_idx = int(np.argmax(end_logit))
            if start_idx <= end_idx < len(offset):
                start_char, _ = offset[start_idx]
                _, end_char = offset[end_idx]
                span = context[start_char:end_char].strip()
                score = start_logit[start_idx] + end_logit[end_idx]
                if score > best_score and span:
                    best_score, best_span = score, span
        predictions.append(
            {"id": example["id"], "prediction_text": best_span}
        )
    return predictions
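# Hedged sketch (an addition, not wired into the Trainer below): how
# postprocess_qa pairs with the "squad" metric after trainer.predict().
# Assumes `examples` are the raw rows (with "id"/"answers") and `features`
# still carry "example_id"/"offset_mapping".
def evaluate_squad(examples, features, raw_predictions, tokenizer):
    preds = postprocess_qa(examples, features, raw_predictions, tokenizer)
    refs = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=preds, references=refs)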
# ──────────────────────────────────────────────────────── preprocessing ──
def preprocess_training_batch(examples, tokenizer):
    """Training preprocessing - NO offset_mapping included"""
    questions = examples["question"]
    contexts = examples["context"]
    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")
    start_positions = []
    end_positions = []
    for i, offsets in enumerate(offset_mapping):
        cls_index = 0
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        if not answers["text"] or not answers["text"][0]:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue
        answer_start_char = answers["answer_start"][0]
        answer_text = answers["text"][0]
        answer_end_char = answer_start_char + len(answer_text)
        # Restrict the search to context tokens (sequence id 1): question-token
        # offsets index into the question string and can collide with answer
        # character positions by accident.
        sequence_ids = tokenized_examples.sequence_ids(i)
        token_start_index = cls_index
        token_end_index = cls_index
        for token_index, (start_char, end_char) in enumerate(offsets):
            if sequence_ids[token_index] != 1:
                continue
            if start_char <= answer_start_char < end_char:
                token_start_index = token_index
            if start_char < answer_end_char <= end_char:
                token_end_index = token_index
                break
        # Fall back to CLS when the answer is not fully inside this chunk
        if 0 < token_start_index <= token_end_index:
            start_positions.append(token_start_index)
            end_positions.append(token_end_index)
        else:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples
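# Hedged debug helper (an assumption, not called during training): decode each
# labelled span back to text to eyeball that the alignment above worked.
def sanity_check_alignment(dataset, tokenizer, n=2):
    batch = {k: dataset[k][:n] for k in ("question", "context", "answers")}
    feats = preprocess_training_batch(batch, tokenizer)
    for i in range(len(feats["input_ids"])):
        s, e = feats["start_positions"][i], feats["end_positions"][i]
        if s > 0:  # skip no-answer features labelled at CLS
            print(f"feature {i}: {tokenizer.decode(feats['input_ids'][i][s:e + 1])!r}")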
def preprocess_validation_batch(examples, tokenizer):
    """Validation preprocessing - INCLUDES offset_mapping and example_id"""
    questions = examples["question"]
    contexts = examples["context"]
    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=MAX_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    tokenized_examples["example_id"] = [
        examples["id"][sample_mapping[i]] for i in range(len(tokenized_examples["input_ids"]))
    ]
    # Zero out offsets for non-context tokens so span extraction in
    # postprocess_qa can never select a question token (empty spans are skipped)
    for i in range(len(tokenized_examples["input_ids"])):
        sequence_ids = tokenized_examples.sequence_ids(i)
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == 1 else (0, 0))
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]
    return tokenized_examples
def preprocess_dataset_streaming(dataset, tokenizer, desc="Processing", is_training=True):
    """Process dataset in batches using HuggingFace's map function with batching."""
    print(f"🔄 {desc} dataset with batch processing...")
    if is_training:
        preprocess_fn = preprocess_training_batch
    else:
        preprocess_fn = preprocess_validation_batch
    processed = dataset.map(
        lambda examples: preprocess_fn(examples, tokenizer),
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=dataset.column_names,
        desc=desc,
        num_proc=1,
    )
    print(f"✅ {desc} completed: {len(processed)} features")
    return processed
# ───────────────────────────────────────────────────────────────── main ──
def main():
    set_seed(SEED)
    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v4")
    if (tokn := os.getenv("roberta_token")):
        try:
            login(tokn)
            print("🔐 HuggingFace Hub login OK")
        except Exception as e:
            print(f"⚠️ Hub login failed: {e}")
            tokn = None
    print("📥 Loading CUAD…")
    try:
        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
        print(f"✅ Loaded {len(cuad)} examples")
    except Exception as e:
        print(f"❌ Dataset loading failed: {e}")
        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True,
                            download_mode="force_redownload")
    cuad = cuad.shuffle(seed=SEED)
    # FIXED: Apply subset reduction more aggressively
    subset_size = SUBSET_SIZE if USE_SUBSET else None
    cuad = balance_has_answer(cuad, ratio=1.5, max_samples=subset_size)  # Reduced ratio too
    print(f"📊 Final dataset size: {len(cuad)} examples")
    # Estimate features after preprocessing
    avg_features_per_example = 2.5  # Conservative estimate with stride
    estimated_features = len(cuad) * avg_features_per_example
    print(f"📊 Estimated training features: ~{int(estimated_features)}")
    ds = cuad.train_test_split(test_size=0.1, seed=SEED)
    train_raw, val_raw = ds["train"], ds["test"]
    # ── tokeniser & model ──
    base_ckpt = "deepset/roberta-base-squad2"
    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
    model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
    # FIXED: Lighter LoRA config for faster training
    lora = LoraConfig(
        task_type=TaskType.QUESTION_ANS,
        r=16,                               # Reduced from 32
        lora_alpha=32,                      # Reduced from 64
        lora_dropout=0.1,
        target_modules=["query", "value"],  # Fewer modules
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()
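    # Note (assumption, not part of the original config): RoBERTa attention
    # also exposes "key" and the output "dense" projection, so
    # target_modules=["query", "key", "value", "dense"] would trade training
    # speed for extra adapter capacity; ["query", "value"] is the light default.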
    # ── preprocessing ─────────────────────────────────────────
    print("🔄 Starting preprocessing...")
    train_feats = preprocess_dataset_streaming(train_raw, tok, "Training", is_training=True)
    val_feats = preprocess_dataset_streaming(val_raw, tok, "Validation", is_training=False)
    print("✅ Preprocessing completed!")
    print(f"   Training features: {len(train_feats)}")
    print(f"   Validation features: {len(val_feats)}")
    # ── training args - FIXED for reasonable training time ──
    batch_size = 16                  # Good balance
    gradient_accumulation_steps = 2
    effective_batch_size = batch_size * gradient_accumulation_steps
    num_epochs = 3                   # Keep it reasonable
    steps_per_epoch = len(train_feats) // effective_batch_size
    total_steps = steps_per_epoch * num_epochs
    eval_steps = max(25, steps_per_epoch // 8)  # More frequent eval
    save_steps = eval_steps * 3
    print("📊 Training configuration:")
    print(f"   Effective batch size: {effective_batch_size}")
    print(f"   Steps per epoch: {steps_per_epoch}")
    print(f"   Total steps: {total_steps}")
    # Rough wall-clock estimate assuming ~2.4 optimizer steps per second
    print(f"   Estimated time: ~{total_steps / 2.4 / 60:.1f} minutes")
    print(f"   Eval every: {eval_steps} steps")
    args = TrainingArguments(
        output_dir="./cuad_lora_out",
        learning_rate=3e-5,          # Slightly lower LR
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=gradient_accumulation_steps,
        fp16=False, bf16=True,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        save_total_limit=2,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        load_best_model_at_end=False,
        logging_steps=10,            # More frequent logging
        report_to="none",
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        remove_unused_columns=True,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_feats,
        eval_dataset=val_feats,
        tokenizer=tok,
        data_collator=default_data_collator,
        compute_metrics=None,
    )
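    # compute_metrics is left as None on purpose: span-level EM/F1 needs the
    # raw examples, so SQuAD metrics are computed post hoc (see the
    # evaluate_squad sketch above). Also note the validation features carry no
    # start/end positions, so step-wise evaluation will not report eval_loss.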
print("π Trainingβ¦") | |
try: | |
trainer.train() | |
print("β Training completed successfully!") | |
except Exception as e: | |
print(f"β Training failed: {e}") | |
try: | |
trainer.save_model("./cuad_lora_out_partial") | |
tok.save_pretrained("./cuad_lora_out_partial") | |
print("πΎ Partial model saved") | |
except: | |
print("β Could not save partial model") | |
raise e | |
print("β Done. Best eval_loss:", trainer.state.best_metric) | |
trainer.save_model("./cuad_lora_out") | |
tok.save_pretrained("./cuad_lora_out") | |
    # Push to hub
    if tokn:
        for attempt in range(3):
            try:
                print(f"⬆️ Pushing to Hub (attempt {attempt + 1}/3)...")
                # Trainer.push_to_hub's first positional arg is a commit
                # message, not a repo id - push the adapter and tokenizer
                # directly to the target repo instead
                trainer.model.push_to_hub(model_repo, private=False)
                tok.push_to_hub(model_repo, private=False)
                print("🎉 Pushed to:", f"https://huggingface.co/{model_repo}")
                break
            except Exception as e:
                print(f"⚠️ Hub push failed: {e}")
                if attempt < 2:
                    time.sleep(30)
                else:
                    print("💾 Model saved locally (push failed)")
if __name__ == "__main__":
    main()
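# Hedged usage sketch (assumptions: the output dir above, and a peft release
# that ships the AutoPeftModel* classes) - loading the saved adapter later:
#
#   from peft import AutoPeftModelForQuestionAnswering
#   model = AutoPeftModelForQuestionAnswering.from_pretrained("./cuad_lora_out")
#   tok = AutoTokenizer.from_pretrained("./cuad_lora_out")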