workaround for md5 variations (#533)
* workaround for md5 variations
* refactor the prepared hash too
- src/axolotl/utils/data.py +15 -13
- tests/test_data.py +64 -0
    	
src/axolotl/utils/data.py  CHANGED

@@ -2,7 +2,6 @@
 import functools
 import hashlib
 import logging
-from hashlib import md5
 from pathlib import Path
 from typing import Tuple, Union
 
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
+def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    try:
+        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
+    except TypeError:
+        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
+
+
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
 ) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
-        md5(  # nosec
+        md5(
             (
                 str(cfg.sequence_len)
                 + "@"
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
                 )
                 + "|"
                 + tokenizer_name
-            ).encode("utf-8")
-        ).hexdigest()
+            )
+        )
     )
     prepared_ds_path = (
         Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
         ds_hash = str(
-            md5(  # nosec
+            md5(
                 (
                     str(cfg.sequence_len)
                     + "@"
@@ -385,8 +391,8 @@ def load_prepare_datasets(
                     )
                     + "|"
                     + tokenizer_name
-                ).encode("utf-8")
-            ).hexdigest()
+                )
+            )
         )
         prepared_ds_path = (
             Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
             + "|"
             + str(cfg.seed or 42)
         )
-        train_fingerprint = hashlib.md5(
-            to_hash_train.encode(), usedforsecurity=False
-        ).hexdigest()
-        test_fingerprint = hashlib.md5(
-            to_hash_test.encode(), usedforsecurity=False
-        ).hexdigest()
+        train_fingerprint = md5(to_hash_train)
+        test_fingerprint = md5(to_hash_test)
 
         with zero_first(is_main_process()):
             dataset = dataset.train_test_split(
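For context: hashlib.md5() only accepts the usedforsecurity keyword on Python 3.9 and newer, so passing it unconditionally raises TypeError on other interpreters. The new md5() helper above absorbs that variation. Below is a minimal self-contained sketch of the same pattern; md5_compat is a hypothetical name used here only so the example does not shadow anything in axolotl.

import hashlib

def md5_compat(to_hash: str, encoding: str = "utf-8") -> str:
    # Try the keyword form first so security scanners see the hash is not used
    # for security purposes; fall back when the interpreter rejects the keyword.
    try:
        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
    except TypeError:
        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec

assert md5_compat("hello world") == "5eb63bbbe01eeed093cb22bb8f5acdc3"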
    	
tests/test_data.py  ADDED

@@ -0,0 +1,64 @@
+"""
+test module for the axolotl.utis.data module
+"""
+import unittest
+
+from transformers import LlamaTokenizer
+
+from axolotl.utils.data import encode_pretraining, md5
+
+
+class TestEncodePretraining(unittest.TestCase):
+    """
+    test class for encode pretraining and md5 helper
+    """
+
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "eos_token": "</s>",
+                "bos_token": "<s>",
+                "unk_token": "<unk>",
+                "pad_token": "<pad>",
+            }
+        )
+        self.max_tokens = 15  # set a small number for easy inspection
+
+    def test_encode_pretraining(self):
+        examples = {
+            "text": [
+                "Hello, world!",
+                "Nice to meet you.",
+                "lorem ipsum dolor sit amet.",
+                "Nice to meet you again!.",
+                "hello, hello",
+            ]
+        }
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+
+        self.assertEqual(len(result["input_ids"]), 3)
+
+        # Assert the length of input_ids and attention_mask is correct
+        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
+        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
+
+        # Assert EOS and PAD tokens are correctly added
+        # hello world! is 4 tokens
+        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
+        # second part, 5 tokens
+        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
+        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
+        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
+
+    def test_md5(self):
+        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
+        self.assertEqual(
+            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
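Because the test module guards unittest.main() under __main__, it can be executed directly with python tests/test_data.py or collected by a test runner such as pytest; note that setUp() loads the huggyllama/llama-7b tokenizer, so the full test class needs network access or a local cache of that tokenizer. The literal asserted in test_md5 is simply the standard MD5 digest of the UTF-8 bytes of "hello world", which can be reproduced with hashlib alone:

import hashlib

# Reference value for the digest asserted in test_md5.
print(hashlib.md5("hello world".encode("utf-8")).hexdigest())  # 5eb63bbbe01eeed093cb22bb8f5acdc3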
