Support for additional_special_tokens (#1221) [skip ci]
* Support for additional_special_tokens
* Support for additional_special_tokens. Adjust whitespace.
* Support for additional_special_tokens. Use correct quotes.
* Support for additional_special_tokens. Safe pop.
* Support for additional_special_tokens. nt.
* Support for additional_special_tokens. cfg.special_tokens may be None.
* add token if not in vocabulary when adding additional_special_tokens
* fix logic for copy/pasta
* bugfix for popping from config and tokenizer reload
* no need to add tokens manually now with previous bugfix
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/utils/models.py +22 -2
- tests/test_tokenizers.py +15 -0
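
For reference, the change builds on the standard Hugging Face tokenizer API: `add_special_tokens` accepts `additional_special_tokens` as a list and extends the vocabulary for any entries not already present. A minimal sketch of that behavior, mirroring the model and tokens used later in this diff (illustrative only, not part of the change itself):

    # Sketch of the transformers behavior this PR relies on; model/tokens mirror the new test.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

    # additional_special_tokens is a list, unlike bos_token/eos_token/pad_token, which are single strings
    num_added = tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
    )

    # the vocabulary grows by the number of tokens that were actually new
    print(num_added, len(tokenizer))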
    	
src/axolotl/utils/models.py
CHANGED

@@ -161,15 +161,20 @@ def load_tokenizer(cfg):
             if getattr(tokenizer, attr_name) is None:
                 setattr(tokenizer, attr_name, "<|endoftext|>")
 
+    additional_special_tokens = None
     if cfg.special_tokens:
+        special_tokens = cfg.special_tokens.to_dict()
+        additional_special_tokens = special_tokens.pop(
+            "additional_special_tokens", None
+        )
         lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
-        for k, val in cfg.special_tokens.items():
+        for k, val in special_tokens.items():
             # check if new special token is not already in tokenizer and
             # is adapter training to make sure lora_modules_to_save is set
             # pylint: disable=too-many-boolean-expressions
             if (
                 (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                and (len(tokenizer.encode(val)) > 2)
+                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                 and cfg.adapter
                 and (
                     not cfg.lora_modules_to_save
@@ -213,6 +218,21 @@ def load_tokenizer(cfg):
             ]
         )
 
+    # Additional special tokens are a List, and need to be treated differently than regular special
+    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
+    # are new tokens.
+    #
+    # Usage:
+    #
+    # ```py
+    # special_tokens:
+    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
+    # ```
+    if additional_special_tokens is not None:
+        tokenizer.add_special_tokens(
+            {"additional_special_tokens": additional_special_tokens}
+        )
+
     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
     LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
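
One detail worth noting in the hunk above: `additional_special_tokens` is popped from a copy of the config (`cfg.special_tokens.to_dict()`) rather than from `cfg.special_tokens` itself. That is what the "bugfix for popping from config and tokenizer reload" commit refers to: popping from the config in place would drop the key, so a second `load_tokenizer(cfg)` call would silently skip the additional tokens. A small sketch of the difference (names are illustrative, not axolotl code):

    # Illustrative sketch: why the code pops from a copy of the special-tokens mapping.
    cfg_special_tokens = {"bos_token": "<s>", "additional_special_tokens": ["<|im_start|>"]}

    def load_from_copy(mapping):
        local = dict(mapping)                    # like cfg.special_tokens.to_dict()
        extra = local.pop("additional_special_tokens", None)
        return local, extra                      # scalar tokens for the loop, list handled separately

    def load_in_place(mapping):
        extra = mapping.pop("additional_special_tokens", None)  # mutates the shared config
        return mapping, extra

    print(load_from_copy(cfg_special_tokens))    # works on every call
    print(load_from_copy(cfg_special_tokens))
    print(load_in_place(cfg_special_tokens))     # first call works...
    print(load_in_place(cfg_special_tokens))     # ...second call no longer sees the list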
    	
tests/test_tokenizers.py
CHANGED

@@ -67,6 +67,21 @@ class TestTokenizers(unittest.TestCase):
         )
         load_tokenizer(cfg)
 
+    def test_add_additional_special_tokens(self):
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
+        self.assertEqual(len(tokenizer), 32001)
+
+        # ensure reloading the tokenizer again from cfg results in same vocab length
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(len(tokenizer), 32001)
+
 
 if __name__ == "__main__":
     unittest.main()
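
As a rough sanity check of the asserted ids: huggyllama/llama-7b has a 32,000-entry vocabulary, so the single added token lands at id 32000 (bringing the tokenizer length to 32001), and `<|im_start|>user` encodes to BOS (1), the new special token (32000), and the `user` text (1404). The same expectation can be approximated with plain transformers, outside of axolotl's `load_tokenizer` (illustrative only):

    # Rough standalone reproduction of the new test's expectation (illustrative only).
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    tok.add_special_tokens({"additional_special_tokens": ["<|im_start|>"]})

    print(len(tok))                               # 32001
    print(tok("<|im_start|>user")["input_ids"])   # expected [1, 32000, 1404]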
