| """ | |
| Test dataset loading under various conditions. | |
| """ | |

import shutil
import tempfile
import unittest
from pathlib import Path

from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault


class TestDatasetPreparation(unittest.TestCase):
    """Test a configured dataloader."""

    def setUp(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )
        # Alpaca dataset.
        self.dataset = Dataset.from_list(
            [
                {
                    "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                    "input": "He finnished his meal and left the resturant",
                    "output": "He finished his meal and left the restaurant.",
                }
            ]
        )

    def test_load_hub(self):
        """Core use case. Verify that processing data from the hub works."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"
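            # prepared_path stays a Path on purpose: load_tokenized_prepared_datasets
            # is expected to wrap it in str() before handing it to fsspec, which
            # raises a TypeError on Path objects (see #1548).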
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_local_hub(self):
        """Niche use case. Verify that a local copy of a hub dataset can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
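            # Note: this directory is relative to the current working directory,
            # not tmp_dir, so it is cleaned up explicitly via shutil.rmtree below.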
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )

            prepared_path = Path(tmp_dir) / "prepared"
            # Right now a local copy that doesn't fully conform to a dataset
            # must list data_files and ds_type, otherwise the loader won't know
            # how to load it.
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
                            ],
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

            shutil.rmtree(tmp_ds_path)

    def test_load_from_save_to_disk(self):
        """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
            self.dataset.save_to_disk(str(tmp_ds_name))

            prepared_path = Path(tmp_dir) / "prepared"
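            # save_to_disk writes the datasets-native Arrow layout, which the
            # loader is expected to pick up via load_from_disk, so no ds_type or
            # data_files are needed in the config.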
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_name),
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_parquet(self):
        """Usual use case. Verify a directory of parquet files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
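            # For a bare directory of files, ds_type and data_files tell the
            # generic loader which builder to use and which shards to read.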
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "parquet",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_json(self):
        """Standard use case. Verify a directory of json files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "json",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_parquet(self):
        """Standard use case. Verify a single parquet file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
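            # No ds_type or data_files here: for a single file, the loader
            # appears to infer the format from the file extension (likewise for
            # the single-json test below).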
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_json(self):
        """Standard use case. Verify a single json file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features


if __name__ == "__main__":
    unittest.main()