#!/usr/bin/env python
# train_cuad_lora.py
"""
CUAD fine-tune with LoRA on an L4 / T4 GPU.
Expected wall-clock on Nvidia L4: ~25-30 min.
"""
import os, random
from collections import defaultdict
import torch, numpy as np
from datasets import load_dataset, Dataset, disable_caching
from transformers import (
AutoTokenizer, AutoModelForQuestionAnswering,
TrainingArguments, default_data_collator
)
# transformers ships no QuestionAnsweringTrainer (that class lives in the
# repo's examples/ folder); the plain Trainer suffices here because
# compute_metrics reads the offset mappings from module globals instead.
from transformers import Trainer, EvalPrediction
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from huggingface_hub import login
disable_caching() # avoids giant disk cache on Colab
# ─────────────────────────────────────────────────────────────── helpers ──
MAX_LEN = 384 # window
DOC_STRIDE = 128
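# Sliding-window tokenisation: 384-token windows overlapping by 128 tokens,
# so an answer near a window boundary still appears whole in some window.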
SEED = 42
def set_seed(seed):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def balance_has_answer(dataset, ratio=2.0):
    """Keep every has-answer row; down-sample no-answer rows to at most
    `ratio` times the has-answer count, so CUAD's mostly-negative pool
    doesn't dominate the loss."""
has, no = [], []
for ex in dataset:
(has if ex["answers"]["text"] else no).append(ex)
k = int(len(has) * ratio)
no = random.sample(no, min(k, len(no)))
return Dataset.from_list(has + no)
# ────────────────────────────────────────────────────────────── postproc ──
# CUAD keeps unanswerable questions, so score with squad_v2, which accepts
# empty gold-answer lists (the v1 squad metric chokes on them).
metric = evaluate.load("squad_v2")
def postprocess_qa(examples, features, raw_predictions, tokenizer):
    """Greedy argmax span extraction per window (not a full n-best search);
    returns predictions in SQuAD-v2 format."""
all_start, all_end = raw_predictions
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = defaultdict(list)
for i, feat_id in enumerate(features["example_id"]):
features_per_example[example_id_to_index[feat_id]].append(i)
predictions = []
for example_idx, example in enumerate(examples):
best_score = -1e9
best_span = ""
context = example["context"]
for feat_idx in features_per_example[example_idx]:
start_logit = all_start[feat_idx]
end_logit = all_end[feat_idx]
offset = features["offset_mapping"][feat_idx]
start_idx = int(np.argmax(start_logit))
end_idx = int(np.argmax(end_logit))
if start_idx <= end_idx < len(offset):
start_char, _ = offset[start_idx]
_, end_char = offset[end_idx]
span = context[start_char:end_char].strip()
score = start_logit[start_idx] + end_logit[end_idx]
if score > best_score and span:
best_score, best_span = score, span
        predictions.append(
            # squad_v2 scoring expects a no-answer probability; an empty
            # prediction_text already signals "no answer", so report 0.0.
            {"id": example["id"], "prediction_text": best_span,
             "no_answer_probability": 0.0}
        )
return predictions
def compute_metrics(eval_pred: EvalPrediction):
    # raw_val, val_feats and tok are published as module globals from main().
    predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions, tok)
references = [
{"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
]
return metric.compute(predictions=predictions, references=references)
# ───────────────────────────────────────────────────────────────── main ──
def main():
    # compute_metrics() reads these during evaluation, so publish main()'s
    # assignments at module scope (plain locals would raise NameError there).
    global raw_val, val_feats, tok
    set_seed(SEED)
    # model name to store on the Hub
    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
    if (tokn := os.getenv("roberta_token")):
        try:
            login(tokn)
            print("🔑 Hugging Face Hub login OK")
        except Exception as e:
            print("Hub login failed:", e)
    print("📚 Loading CUAD…")
cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
cuad = cuad.shuffle(seed=SEED)
    cuad = balance_has_answer(cuad, ratio=2.0)   # ≈18k rows
# train / val 90-10
ds = cuad.train_test_split(test_size=0.1, seed=SEED)
train_raw, val_raw = ds["train"], ds["test"]
# ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
base_ckpt = "deepset/roberta-base-squad2"
tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
# LoRA
lora = LoraConfig(
task_type=TaskType.QUESTION_ANS,
r=16, lora_alpha=32, lora_dropout=0.05,
target_modules=["query", "value"],
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()
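    # With TaskType.QUESTION_ANS, PEFT should also mark the qa_outputs span
    # head as trainable (modules_to_save), so the printout above counts the
    # LoRA A/B matrices plus that head.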
# ── preprocess ─────────────────────────────────────────────────────
    def preprocess(examples):
        tokenized = tok(
            examples["question"],
            examples["context"],
            truncation="only_second",
            max_length=MAX_LEN,
            stride=DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        # Long contracts overflow into several windows, so len(features) >=
        # len(examples); map every window back to its source example instead
        # of zipping examples["id"] positionally (the original bug).
        sample_map = tokenized.pop("overflow_to_sample_mapping")
        tokenized["example_id"] = [examples["id"][i] for i in sample_map]
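        # The original script built no start/end labels, so the model had
        # nothing to learn from. Minimal SQuAD-style alignment sketch: take
        # the first gold span; windows that miss it get the CLS position
        # (index 0 for RoBERTa), the usual no-answer convention.
        start_positions, end_positions = [], []
        for i, offsets in enumerate(tokenized["offset_mapping"]):
            answers = examples["answers"][sample_map[i]]
            if not answers["text"]:                       # unanswerable row
                start_positions.append(0); end_positions.append(0)
                continue
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            seq_ids = tokenized.sequence_ids(i)
            ctx_start = seq_ids.index(1)                  # first context token
            ctx_end = len(seq_ids) - 1 - seq_ids[::-1].index(1)
            if not (offsets[ctx_start][0] <= start_char and offsets[ctx_end][1] >= end_char):
                start_positions.append(0); end_positions.append(0)
                continue
            ts = ctx_start
            while ts <= ctx_end and offsets[ts][0] <= start_char:
                ts += 1
            te = ctx_end
            while te >= ctx_start and offsets[te][1] >= end_char:
                te -= 1
            start_positions.append(ts - 1)
            end_positions.append(te + 1)
        tokenized["start_positions"] = start_positions
        tokenized["end_positions"] = end_positions
        return tokenized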
train_feats = train_raw.map(
preprocess, batched=True, remove_columns=train_raw.column_names,
num_proc=4, desc="tokenise-train"
)
val_feats = val_raw.map(
preprocess, batched=True, remove_columns=val_raw.column_names,
num_proc=4, desc="tokenise-val"
)
    raw_val = val_raw  # exposed via the global declaration for compute_metrics
# ── training args ──────────────────────────────────────────────────
args = TrainingArguments(
output_dir="./cuad_lora_out",
learning_rate=3e-5,
num_train_epochs=4,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
gradient_accumulation_steps=4, # eff. BS 32
        fp16=False, bf16=True,            # bf16 for L4/Ampere+; on a T4 use fp16=True, bf16=False
evaluation_strategy="steps",
eval_steps=250,
save_steps=500,
save_total_limit=2,
weight_decay=0.01,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
logging_steps=50,
report_to="none",
)
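    # NOTE: recent transformers releases renamed evaluation_strategy to
    # eval_strategy; if your installed version rejects the kwarg above,
    # switch the name accordingly.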
    trainer = Trainer(
model=model,
args=args,
train_dataset=train_feats,
eval_dataset=val_feats,
tokenizer=tok,
data_collator=default_data_collator,
compute_metrics=compute_metrics,
)
    print("🚀 Training…")
    trainer.train()
    print("✅ Done. Best F1:", trainer.state.best_metric)
trainer.save_model("./cuad_lora_out")
tok.save_pretrained("./cuad_lora_out")
    # optional: push the LoRA adapter + tokenizer to the Hub
    if tokn:
        # Trainer.push_to_hub targets args.hub_model_id (unset here) and
        # treats its first positional arg as a commit message, so push the
        # PEFT adapter and the tokenizer directly instead.
        model.push_to_hub(model_repo, private=False)
        tok.push_to_hub(model_repo, private=False)
        print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
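        # Reloading later (sketch; assumes the adapter repo above exists):
        #   base = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
        #   qa_model = PeftModel.from_pretrained(base, model_repo)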
if __name__ == "__main__":
main()