#!/usr/bin/env python
# train_cuad_lora.py
"""
CUAD fine-tune with LoRA on an L4 / T4 GPU.
Expected wall-clock on an NVIDIA L4: ~25-30 min.
"""
import os, random
from collections import defaultdict

import torch, numpy as np
from datasets import load_dataset, Dataset, disable_caching
from transformers import (
    AutoTokenizer, AutoModelForQuestionAnswering,
    Trainer, TrainingArguments, EvalPrediction, default_data_collator,
)
# NOTE: transformers does not ship a QuestionAnsweringTrainer (it lives in the
# library's examples folder); the stock Trainer works here because the eval
# features built below carry start/end labels.
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from huggingface_hub import login

disable_caching()  # avoids a giant on-disk datasets cache on Colab
# ──────────────────────────────────────────────────────────────── helpers ──
MAX_LEN = 384      # max tokens per window (question + context)
DOC_STRIDE = 128   # tokens shared by consecutive windows
SEED = 42

def set_seed(seed):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def balance_has_answer(dataset, ratio=2.0):
    """Keep every has-answer row; down-sample no-answer rows to at most
    `ratio` times the has-answer count."""
    has, no = [], []
    for ex in dataset:
        (has if ex["answers"]["text"] else no).append(ex)
    k = int(len(has) * ratio)
    no = random.sample(no, min(k, len(no)))
    return Dataset.from_list(has + no)
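# Example: if the split has 6k has-answer rows, ratio=2.0 keeps at most 12k
# no-answer rows, capping the balanced set at ~18k rows (the figure quoted
# in main() below).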
# ─────────────────────────────────────────────────────────────── postproc ──
# squad_v2 rather than squad: the balanced set keeps unanswerable rows, and
# the plain squad metric cannot score empty ground-truth answer lists.
metric = evaluate.load("squad_v2")
def postprocess_qa(examples, features, raw_predictions):
    """Greedy span extraction: argmax start/end per window, best-scoring
    window wins. Returns predictions in SQuAD-v2 format."""
    all_start, all_end = raw_predictions
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = defaultdict(list)
    for i, feat_id in enumerate(features["example_id"]):
        features_per_example[example_id_to_index[feat_id]].append(i)

    predictions = []
    for example_idx, example in enumerate(examples):
        best_score = -1e9
        best_span = ""
        context = example["context"]
        for feat_idx in features_per_example[example_idx]:
            start_logit = all_start[feat_idx]
            end_logit = all_end[feat_idx]
            offset = features["offset_mapping"][feat_idx]
            start_idx = int(np.argmax(start_logit))
            end_idx = int(np.argmax(end_logit))
            if start_idx <= end_idx < len(offset):
                start_char, _ = offset[start_idx]
                _, end_char = offset[end_idx]
                span = context[start_char:end_char].strip()
                score = start_logit[start_idx] + end_logit[end_idx]
                if score > best_score and span:
                    best_score, best_span = score, span
        predictions.append(
            # an empty best_span doubles as the "no answer" prediction;
            # squad_v2 additionally wants an explicit no-answer probability
            {"id": example["id"], "prediction_text": best_span,
             "no_answer_probability": 0.0}
        )
    return predictions
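# Shape sketch (the id is hypothetical): postprocess_qa returns e.g.
#   [{"id": "CUAD__Parties__0", "prediction_text": "ACME Corp.",
#     "no_answer_probability": 0.0}, ...]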
def compute_metrics(eval_pred: EvalPrediction):
    # raw_val / val_feats are bound as module globals in main() before training
    predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions)
    references = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
    ]
    return metric.compute(predictions=predictions, references=references)
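# squad_v2 returns a dict including "exact" and "f1" on a 0-100 scale;
# metric_for_best_model="f1" in the TrainingArguments below keys off the
# "eval_f1" entry the Trainer derives from it.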
# ──────────────────────────────────────────────────────────────────── main ──
def main():
    global raw_val, val_feats, tok   # read by compute_metrics at eval time
    set_seed(SEED)

    # repo name for the optional Hub push at the end
    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")

    if (tokn := os.getenv("roberta_token")):
        try:
            login(tokn)
            print("HuggingFace Hub login OK")
        except Exception as e:
            print("Hub login failed:", e)

    print("Loading CUAD…")
    cuad = load_dataset("theatticusproject/cuad-qa", split="train",
                        trust_remote_code=True)
    cuad = cuad.shuffle(seed=SEED)
    cuad = balance_has_answer(cuad, ratio=2.0)   # ≈18k rows

    # 90/10 train / validation split
    ds = cuad.train_test_split(test_size=0.1, seed=SEED)
    train_raw, val_raw = ds["train"], ds["test"]
    # ── tokeniser & model (SQuAD-2 tuned) ──────────────────────────────────
    base_ckpt = "deepset/roberta-base-squad2"
    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
    model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

    # LoRA adapters on the attention query/value projections
    lora = LoraConfig(
        task_type=TaskType.QUESTION_ANS,
        r=16, lora_alpha=32, lora_dropout=0.05,
        target_modules=["query", "value"],
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()
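    # Expected scale (back-of-envelope): r=16 on query/value across the 12
    # roberta-base layers gives 12 x 2 x 2 x (768 x 16) ≈ 0.59M LoRA weights,
    # well under 1% of the ~125M total; PEFT also keeps the qa_outputs head
    # trainable for TaskType.QUESTION_ANS.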
    # ── preprocess ─────────────────────────────────────────────────────────
    def preprocess(examples):
        """Tokenise into overlapping windows and label each window with
        start/end token positions (CLS when the answer is absent)."""
        enc = tok(
            examples["question"],
            examples["context"],
            truncation="only_second",
            max_length=MAX_LEN,
            stride=DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        # one example can overflow into several windows; map each window
        # back to the example it came from
        sample_map = enc.pop("overflow_to_sample_mapping")
        enc["example_id"] = [examples["id"][i] for i in sample_map]
        enc["start_positions"], enc["end_positions"] = [], []

        for i, offsets in enumerate(enc["offset_mapping"]):
            seq_ids = enc.sequence_ids(i)
            ctx_start = seq_ids.index(1)
            ctx_end = len(seq_ids) - 1 - seq_ids[::-1].index(1)
            cls_idx = enc["input_ids"][i].index(tok.cls_token_id)
            answers = examples["answers"][sample_map[i]]

            start_pos = end_pos = cls_idx        # default: no answer → CLS
            if answers["text"]:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                # label a span only when the answer sits fully in this window
                if (offsets[ctx_start][0] <= start_char
                        and offsets[ctx_end][1] >= end_char):
                    s, e = ctx_start, ctx_end
                    while s <= ctx_end and offsets[s][0] <= start_char:
                        s += 1
                    while e >= ctx_start and offsets[e][1] >= end_char:
                        e -= 1
                    start_pos, end_pos = s - 1, e + 1
            enc["start_positions"].append(start_pos)
            enc["end_positions"].append(end_pos)
            # blank out non-context offsets so postprocess_qa can never pick
            # a span from the question
            enc["offset_mapping"][i] = [
                o if seq_ids[k] == 1 else (0, 0) for k, o in enumerate(offsets)
            ]
        return enc

    train_feats = train_raw.map(
        preprocess, batched=True, remove_columns=train_raw.column_names,
        num_proc=4, desc="tokenise-train",
    )
    val_feats = val_raw.map(
        preprocess, batched=True, remove_columns=val_raw.column_names,
        num_proc=4, desc="tokenise-val",
    )
    raw_val = val_raw   # exposed to compute_metrics via the global above
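    # Window fan-out sketch: CUAD contexts are long, so a single row usually
    # becomes several 384-token windows overlapping by DOC_STRIDE=128 tokens;
    # expect train_feats to hold a multiple of len(train_raw) rows (the exact
    # factor depends on contract length).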
    # ── training args ──────────────────────────────────────────────────────
    args = TrainingArguments(
        output_dir="./cuad_lora_out",
        learning_rate=3e-5,
        num_train_epochs=4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,   # effective batch size 32
        fp16=False, bf16=True,           # L4 supports bf16; on a T4 use fp16
        evaluation_strategy="steps",     # renamed eval_strategy in newer transformers
        eval_steps=250,
        save_steps=500,                  # must stay a multiple of eval_steps
        save_total_limit=2,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        report_to="none",
    )
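    # Rough step math (assumes windowing roughly doubles the ~16k train rows):
    # ~32k features / effective batch 32 ≈ 1k optimiser steps per epoch, so
    # eval_steps=250 gives about four evals per epoch.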
    # The stock Trainer stands in for the examples-only QuestionAnsweringTrainer:
    # it drops non-model columns (example_id, offset_mapping) from its own
    # dataloader copy, while the global val_feats keeps them for postprocess_qa,
    # and the labelled eval features let it call compute_metrics directly.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_feats,
        eval_dataset=val_feats,
        tokenizer=tok,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
    )

    print("Training…")
    trainer.train()
    print("Done. Best F1:", trainer.state.best_metric)

    trainer.save_model("./cuad_lora_out")
    tok.save_pretrained("./cuad_lora_out")
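    # Inference sketch (paths as saved above): reload the adapter with
    #   from peft import PeftModel
    #   base = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
    #   qa_model = PeftModel.from_pretrained(base, "./cuad_lora_out")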
    # optional: push adapter + tokenizer to the Hub
    if tokn:
        # Trainer.push_to_hub would target args.hub_model_id, so push the
        # PEFT model and tokenizer directly to the chosen repo instead
        model.push_to_hub(model_repo, private=False)
        tok.push_to_hub(model_repo, private=False)
        print("Pushed to:", f"https://huggingface.co/{model_repo}")

if __name__ == "__main__":
    main()