Create train_script.py
train_script.py (new file, +100 lines)
import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
    SparseEncoderModelCardData,
)
from sentence_transformers.models import Asym
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss
from sentence_transformers.sparse_encoder.models import SpladePooling, MLMTransformer, IDF
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
mlm_transformer = MLMTransformer("prajjwal1/bert-tiny", tokenizer_args={"model_max_length": 512})
splade_pooling = SpladePooling(
    pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension()
)

asym = Asym({
    "query": [IDF(tokenizer=mlm_transformer.tokenizer, frozen=False)],
    "document": [mlm_transformer, splade_pooling],
})

model = SparseEncoder(
    modules=[asym],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Inference-free SPLADE BERT-tiny trained on Natural-Questions tuples",
    ),
)
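
# Note on the architecture: Asym routes each input by key, so queries only pass through the
# IDF module (a per-token weight lookup, with no transformer forward pass), which is what makes
# this SPLADE variant inference-free on the query side; documents still run through the full
# MLM transformer followed by SPLADE max pooling.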

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
# full_dataset = full_dataset.map(lambda sample: {"query": {"query": sample["query"]}, "corpus": {"corpus": sample["answer"]}}, remove_columns=["query", "answer"])
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])
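
# Each sample is a {"query": ..., "answer": ...} pair; the `router_mapping` training argument
# below routes these two columns to the "query" and "document" sides of the Asym module.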

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    lambda_query=0,
    lambda_corpus=3e-2,
)
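
# lambda_query=0 disables the query-side sparsity regularizer: the IDF query encoder already
# emits exactly one active dimension per query token, so only the document side needs the
# FLOPS-style regularization (weighted by lambda_corpus) to stay sparse.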

# 5. (Optional) Specify training arguments
run_name = "inference-free-splade-bert-tiny-nq-fresh-3e-2-lambda-corpus-1e-3-idf-lr-2e-5-lr"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    learning_rate_mapping={r"IDF\.weight": 1e-3},  # Set a higher learning rate for the IDF module
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    router_mapping=["query", "document"],  # Route the dataset columns, in order, to the "query" and "document" Asym routes
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=128)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
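
For reference, a minimal usage sketch for the trained model (not part of the committed script). It assumes a sentence-transformers release where SparseEncoder exposes encode_query / encode_document to select the matching Asym route; on older versions the inputs may instead need to be wrapped as dicts such as {"query": ...}.

from sentence_transformers import SparseEncoder

# Load the checkpoint saved by the script above
model = SparseEncoder(
    "models/inference-free-splade-bert-tiny-nq-fresh-3e-2-lambda-corpus-1e-3-idf-lr-2e-5-lr/final"
)

# Queries use the IDF route (no transformer forward pass); documents use MLM + SPLADE pooling
query_embeddings = model.encode_query(["who invented the telephone"])
document_embeddings = model.encode_document([
    "Alexander Graham Bell was credited with patenting the first practical telephone.",
    "The Wright brothers made the first controlled, sustained flights in 1903.",
])

# Rank documents by sparse dot-product similarity
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)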

