# File size: 3,758 bytes
# Revision: 606b2d3
from datasets import load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
SparseEncoderModelCardData,
)
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.models import SpladePooling, MLMTransformer, IDF
from sentence_transformers.models import Asym
import logging
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# 1. Build the model to finetune.
# Document branch: BERT with its MLM head, followed by SPLADE max pooling sized to the
# MLM transformer's output dimension.
mlm_transformer = MLMTransformer("bert-base-uncased", tokenizer_args={"model_max_length": 512})
mlm_output_dim = mlm_transformer.get_sentence_embedding_dimension()
splade_pooling = SpladePooling(pooling_strategy="max", word_embedding_dimension=mlm_output_dim)

# Query branch: only a trainable (frozen=False) IDF module sharing the document tokenizer,
# i.e. no transformer runs on queries -- this is what makes the model "inference-free".
query_route = [IDF(tokenizer=mlm_transformer.tokenizer, frozen=False)]
document_route = [mlm_transformer, splade_pooling]
asym = Asym({"query": query_route, "document": document_route})

# 2. (Optional) Model card metadata.
model = SparseEncoder(
    modules=[asym],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Inference-free SPLADE bert-base-uncased trained on Natural-Questions tuples",
    ),
)
# 3. Load a dataset to finetune on.
# Keep the first 100k training rows and hold out 1k of them (fixed seed) for evaluation.
# Columns are passed through unchanged; they are dispatched to the Asym "query" /
# "document" branches via `router_mapping` in the training arguments, so no per-sample
# remapping into route-keyed dicts is needed here.
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])
# 4. Define the loss: in-batch-negatives ranking loss wrapped by SpladeLoss.
# lambda_query / lambda_corpus are presumably the per-side sparsity regularization
# weights (confirm against the SpladeLoss docs). The query side is zeroed because the
# query branch is only the IDF module, with no MLM head to regularize.
ranking_loss = SparseMultipleNegativesRankingLoss(model=model)
loss = SpladeLoss(
    model=model,
    loss=ranking_loss,
    lambda_query=0,
    lambda_corpus=3e-3,
)
# 5. (Optional) Specify training arguments
# NOTE: run_name hard-codes hyperparameters set elsewhere in this script
# (lambda_corpus=3e-3 in the loss, 1e-3 IDF learning rate, 2e-5 base learning rate);
# update it whenever those values change so output dirs / tracked runs stay accurate.
run_name = "inference-free-splade-bert-base-uncased-nq-3e-3-lambda-corpus-1e-3-idf-lr-2e-5-lr"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    # Keys are matched as regular expressions against parameter names (hence the escaped
    # dot). A raw string is required: a plain "IDF\.weight" literal contains an invalid
    # escape sequence and triggers a SyntaxWarning on Python 3.12+.
    learning_rate_mapping={r"IDF\.weight": 1e-3},  # Set a higher learning rate for the IDF module
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Presumably maps the dataset columns, in order, onto the Asym routes
    # ("query" first, then "document") -- confirm against the training-args docs.
    router_mapping=["query", "document"],
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=400,
    save_strategy="steps",
    save_steps=400,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
# Baseline score of the untrained model, so the post-training evaluation in step 8
# has something to be compared against.
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
# NOTE(review): presumably requires Hub credentials to be configured beforehand.
model.push_to_hub(run_name)