|
|
|
|
|
|
|
import argparse |
|
|
|
from datasets import load_dataset |
|
from sentence_transformers import ( |
|
SentenceTransformer, |
|
SentenceTransformerTrainer, |
|
SentenceTransformerTrainingArguments, |
|
) |
|
from sentence_transformers.evaluation import NanoBEIREvaluator |
|
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss |
|
from sentence_transformers.training_args import BatchSamplers |
|
|
|
def main():
    """Fine-tune a SentenceTransformer on GooAQ question-answer pairs.

    Command-line arguments:
        --lr: peak learning rate (default 8e-5).
        --model_name: Hugging Face model id to fine-tune
            (default "answerdotai/ModernBERT-base").

    Side effects: downloads the GooAQ dataset, runs a NanoBEIR retrieval
    evaluation before and after training, writes checkpoints under
    ``output/<model>/<run_name>``, and pushes the final model to the
    Hugging Face Hub.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base")
    # NOTE: named `cli_args` (not `args`) so it is not shadowed by the
    # SentenceTransformerTrainingArguments object created below.
    cli_args = parser.parse_args()
    lr = cli_args.lr
    model_name = cli_args.model_name
    model_shortname = model_name.split("/")[-1]

    # 1. Load the base model; allow long inputs (8k tokens — TODO confirm the
    #    chosen model actually supports this length).
    model = SentenceTransformer(model_name)
    model.max_seq_length = 8192

    # 2. Load GooAQ (question, answer) pairs and carve out a small eval split.
    dataset = load_dataset("sentence-transformers/gooaq", split="train")
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    # 3. In-batch-negatives ranking loss; the cached variant keeps the large
    #    effective batch while only embedding `mini_batch_size` pairs at a time.
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=128)

    run_name = f"{model_shortname}-gooaq-{lr}"

    training_args = SentenceTransformerTrainingArguments(
        # Checkpoints and logs go here.
        output_dir=f"output/{model_shortname}/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.05,
        fp16=False,  # bf16 is used instead (see below)
        bf16=True,
        # Avoid duplicate samples within a batch: with in-batch negatives,
        # duplicates would act as false negatives.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=10,
        run_name=run_name,  # picked up by experiment trackers such as W&B
    )

    # 4. Baseline retrieval evaluation before any training.
    dev_evaluator = NanoBEIREvaluator(dataset_names=["NQ", "MSMARCO"])
    dev_evaluator(model)

    # 5. Train.
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 6. Re-evaluate to measure the improvement from training.
    dev_evaluator(model)

    # 7. Persist the final model locally and push it to the Hub.
    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
    model.push_to_hub(run_name, private=False)
|
|
|
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()