from datasets import load_dataset
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
    SparseEncoderModelCardData,
)
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.models import SpladePooling, MLMTransformer, IDF
from sentence_transformers.models import Asym
import logging

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
mlm_transformer = MLMTransformer("bert-base-uncased", tokenizer_args={"model_max_length": 512})
splade_pooling = SpladePooling(
    pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension()
)
# Inference-free SPLADE: queries are encoded with a lightweight IDF lookup,
# documents with the full MLM transformer + SPLADE max pooling stack.
asym = Asym(
    {
        "query": [IDF(tokenizer=mlm_transformer.tokenizer, frozen=False)],
        "document": [mlm_transformer, splade_pooling],
    }
)
model = SparseEncoder(
    modules=[asym],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Inference-free SPLADE bert-base-uncased trained on Natural-Questions tuples",
    ),
)

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
# full_dataset = full_dataset.map(lambda sample: {"query": {"query": sample["query"]}, "corpus": {"corpus": sample["answer"]}}, remove_columns=["query", "answer"])
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    lambda_query=0,  # No query regularization: the inference-free IDF query encoder is already sparse
    lambda_corpus=3e-3,  # Sparsity regularization weight for the document embeddings
)

# 5. (Optional) Specify training arguments
run_name = "inference-free-splade-bert-base-uncased-nq-3e-3-lambda-corpus-1e-3-idf-lr-2e-5-lr"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    learning_rate_mapping={r"IDF\.weight": 1e-3},  # Set a higher learning rate for the IDF module
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    router_mapping={"query": "query", "answer": "document"},  # Route the "query" column to the query modules and the "answer" column to the document modules
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=400,
    save_strategy="steps",
    save_steps=400,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
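
# 11. (Optional) Sanity-check the saved model. This is a minimal sketch, not part of the
# training recipe above: it assumes the model was written to models/{run_name}/final in step 9
# and uses the generic SparseEncoder loading, encode_query/encode_document, and similarity APIs.
loaded_model = SparseEncoder(f"models/{run_name}/final")
query_embedding = loaded_model.encode_query("who wrote the declaration of independence")
document_embeddings = loaded_model.encode_document(
    [
        "The Declaration of Independence was primarily drafted by Thomas Jefferson.",
        "The Eiffel Tower is located in Paris, France.",
    ]
)
# The relevant document should receive the higher score.
print(loaded_model.similarity(query_embedding, document_embeddings))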