make sure to cleanup tmp output_dir for e2e tests
Browse files- tests/e2e/test_fused_llama.py +3 -3
- tests/e2e/test_lora_llama.py +7 -7
- tests/e2e/test_mistral.py +5 -5
- tests/e2e/test_mistral_samplepack.py +5 -5
- tests/e2e/test_phi.py +10 -5
- tests/utils.py +22 -0
tests/e2e/test_fused_llama.py
CHANGED
|
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import tempfile
|
| 8 |
import unittest
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 15 |
from axolotl.train import train
|
| 16 |
from axolotl.utils.config import normalize_config
|
| 17 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
@@ -25,9 +25,9 @@ class TestFusedLlama(unittest.TestCase):
|
|
| 25 |
Test case for Llama models using Fused layers
|
| 26 |
"""
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
# pylint: disable=duplicate-code
|
| 30 |
-
output_dir = tempfile.mkdtemp()
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "JackFram/llama-68m",
|
|
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 14 |
from axolotl.train import train
|
| 15 |
from axolotl.utils.config import normalize_config
|
| 16 |
from axolotl.utils.dict import DictDefault
|
| 17 |
+
from tests.utils import with_temp_dir
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 25 |
Test case for Llama models using Fused layers
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
@with_temp_dir
|
| 29 |
+
def test_fft_packing(self, output_dir):
|
| 30 |
# pylint: disable=duplicate-code
|
|
|
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "JackFram/llama-68m",
|
tests/e2e/test_lora_llama.py
CHANGED
|
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import tempfile
|
| 8 |
import unittest
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
@@ -13,6 +12,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 13 |
from axolotl.train import train
|
| 14 |
from axolotl.utils.config import normalize_config
|
| 15 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 16 |
|
| 17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 18 |
os.environ["WANDB_DISABLED"] = "true"
|
|
@@ -23,9 +23,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
| 23 |
Test case for Llama models using LoRA
|
| 24 |
"""
|
| 25 |
|
| 26 |
-
|
|
|
|
| 27 |
# pylint: disable=duplicate-code
|
| 28 |
-
output_dir = tempfile.mkdtemp()
|
| 29 |
cfg = DictDefault(
|
| 30 |
{
|
| 31 |
"base_model": "JackFram/llama-68m",
|
|
@@ -65,9 +65,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
| 65 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 66 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 67 |
|
| 68 |
-
|
|
|
|
| 69 |
# pylint: disable=duplicate-code
|
| 70 |
-
output_dir = tempfile.mkdtemp()
|
| 71 |
cfg = DictDefault(
|
| 72 |
{
|
| 73 |
"base_model": "JackFram/llama-68m",
|
|
@@ -109,9 +109,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
| 109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 110 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 111 |
|
| 112 |
-
|
|
|
|
| 113 |
# pylint: disable=duplicate-code
|
| 114 |
-
output_dir = tempfile.mkdtemp()
|
| 115 |
cfg = DictDefault(
|
| 116 |
{
|
| 117 |
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
|
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 12 |
from axolotl.train import train
|
| 13 |
from axolotl.utils.config import normalize_config
|
| 14 |
from axolotl.utils.dict import DictDefault
|
| 15 |
+
from tests.utils import with_temp_dir
|
| 16 |
|
| 17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 18 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 23 |
Test case for Llama models using LoRA
|
| 24 |
"""
|
| 25 |
|
| 26 |
+
@with_temp_dir
|
| 27 |
+
def test_lora(self, output_dir):
|
| 28 |
# pylint: disable=duplicate-code
|
|
|
|
| 29 |
cfg = DictDefault(
|
| 30 |
{
|
| 31 |
"base_model": "JackFram/llama-68m",
|
|
|
|
| 65 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 66 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 67 |
|
| 68 |
+
@with_temp_dir
|
| 69 |
+
def test_lora_packing(self, output_dir):
|
| 70 |
# pylint: disable=duplicate-code
|
|
|
|
| 71 |
cfg = DictDefault(
|
| 72 |
{
|
| 73 |
"base_model": "JackFram/llama-68m",
|
|
|
|
| 109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 110 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 111 |
|
| 112 |
+
@with_temp_dir
|
| 113 |
+
def test_lora_gptq(self, output_dir):
|
| 114 |
# pylint: disable=duplicate-code
|
|
|
|
| 115 |
cfg = DictDefault(
|
| 116 |
{
|
| 117 |
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
tests/e2e/test_mistral.py
CHANGED
|
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import tempfile
|
| 8 |
import unittest
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 15 |
from axolotl.train import train
|
| 16 |
from axolotl.utils.config import normalize_config
|
| 17 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
|
|
| 25 |
Test case for Llama models using LoRA
|
| 26 |
"""
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
# pylint: disable=duplicate-code
|
| 30 |
-
output_dir = tempfile.mkdtemp()
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
@@ -70,9 +70,9 @@ class TestMistral(unittest.TestCase):
|
|
| 70 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 71 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 72 |
|
| 73 |
-
|
|
|
|
| 74 |
# pylint: disable=duplicate-code
|
| 75 |
-
output_dir = tempfile.mkdtemp()
|
| 76 |
cfg = DictDefault(
|
| 77 |
{
|
| 78 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 14 |
from axolotl.train import train
|
| 15 |
from axolotl.utils.config import normalize_config
|
| 16 |
from axolotl.utils.dict import DictDefault
|
| 17 |
+
from tests.utils import with_temp_dir
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 25 |
Test case for Llama models using LoRA
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
@with_temp_dir
|
| 29 |
+
def test_lora(self, output_dir):
|
| 30 |
# pylint: disable=duplicate-code
|
|
|
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
|
| 70 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 71 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 72 |
|
| 73 |
+
@with_temp_dir
|
| 74 |
+
def test_ft(self, output_dir):
|
| 75 |
# pylint: disable=duplicate-code
|
|
|
|
| 76 |
cfg = DictDefault(
|
| 77 |
{
|
| 78 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
tests/e2e/test_mistral_samplepack.py
CHANGED
|
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import tempfile
|
| 8 |
import unittest
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
| 15 |
from axolotl.train import train
|
| 16 |
from axolotl.utils.config import normalize_config
|
| 17 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
|
|
| 25 |
Test case for Llama models using LoRA
|
| 26 |
"""
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
# pylint: disable=duplicate-code
|
| 30 |
-
output_dir = tempfile.mkdtemp()
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
@@ -71,9 +71,9 @@ class TestMistral(unittest.TestCase):
|
|
| 71 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 72 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 73 |
|
| 74 |
-
|
|
|
|
| 75 |
# pylint: disable=duplicate-code
|
| 76 |
-
output_dir = tempfile.mkdtemp()
|
| 77 |
cfg = DictDefault(
|
| 78 |
{
|
| 79 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 14 |
from axolotl.train import train
|
| 15 |
from axolotl.utils.config import normalize_config
|
| 16 |
from axolotl.utils.dict import DictDefault
|
| 17 |
+
from tests.utils import with_temp_dir
|
| 18 |
|
| 19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 25 |
Test case for Llama models using LoRA
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
@with_temp_dir
|
| 29 |
+
def test_lora_packing(self, output_dir):
|
| 30 |
# pylint: disable=duplicate-code
|
|
|
|
| 31 |
cfg = DictDefault(
|
| 32 |
{
|
| 33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
|
| 71 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 72 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
| 73 |
|
| 74 |
+
@with_temp_dir
|
| 75 |
+
def test_ft_packing(self, output_dir):
|
| 76 |
# pylint: disable=duplicate-code
|
|
|
|
| 77 |
cfg = DictDefault(
|
| 78 |
{
|
| 79 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
tests/e2e/test_phi.py
CHANGED
|
@@ -4,14 +4,15 @@ E2E tests for lora llama
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import tempfile
|
| 8 |
import unittest
|
|
|
|
| 9 |
|
| 10 |
from axolotl.cli import load_datasets
|
| 11 |
from axolotl.common.cli import TrainerCliArgs
|
| 12 |
from axolotl.train import train
|
| 13 |
from axolotl.utils.config import normalize_config
|
| 14 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 15 |
|
| 16 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 17 |
os.environ["WANDB_DISABLED"] = "true"
|
|
@@ -22,7 +23,8 @@ class TestPhi(unittest.TestCase):
|
|
| 22 |
Test case for Llama models using LoRA
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
|
|
|
|
| 26 |
# pylint: disable=duplicate-code
|
| 27 |
cfg = DictDefault(
|
| 28 |
{
|
|
@@ -52,7 +54,7 @@ class TestPhi(unittest.TestCase):
|
|
| 52 |
"num_epochs": 1,
|
| 53 |
"micro_batch_size": 1,
|
| 54 |
"gradient_accumulation_steps": 1,
|
| 55 |
-
"output_dir":
|
| 56 |
"learning_rate": 0.00001,
|
| 57 |
"optimizer": "adamw_bnb_8bit",
|
| 58 |
"lr_scheduler": "cosine",
|
|
@@ -64,8 +66,10 @@ class TestPhi(unittest.TestCase):
|
|
| 64 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 65 |
|
| 66 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
| 69 |
# pylint: disable=duplicate-code
|
| 70 |
cfg = DictDefault(
|
| 71 |
{
|
|
@@ -95,7 +99,7 @@ class TestPhi(unittest.TestCase):
|
|
| 95 |
"num_epochs": 1,
|
| 96 |
"micro_batch_size": 1,
|
| 97 |
"gradient_accumulation_steps": 1,
|
| 98 |
-
"output_dir":
|
| 99 |
"learning_rate": 0.00001,
|
| 100 |
"optimizer": "adamw_bnb_8bit",
|
| 101 |
"lr_scheduler": "cosine",
|
|
@@ -107,3 +111,4 @@ class TestPhi(unittest.TestCase):
|
|
| 107 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 108 |
|
| 109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
|
|
| 4 |
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import unittest
|
| 8 |
+
from pathlib import Path
|
| 9 |
|
| 10 |
from axolotl.cli import load_datasets
|
| 11 |
from axolotl.common.cli import TrainerCliArgs
|
| 12 |
from axolotl.train import train
|
| 13 |
from axolotl.utils.config import normalize_config
|
| 14 |
from axolotl.utils.dict import DictDefault
|
| 15 |
+
from tests.utils import with_temp_dir
|
| 16 |
|
| 17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 18 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 23 |
Test case for Llama models using LoRA
|
| 24 |
"""
|
| 25 |
|
| 26 |
+
@with_temp_dir
|
| 27 |
+
def test_ft(self, output_dir):
|
| 28 |
# pylint: disable=duplicate-code
|
| 29 |
cfg = DictDefault(
|
| 30 |
{
|
|
|
|
| 54 |
"num_epochs": 1,
|
| 55 |
"micro_batch_size": 1,
|
| 56 |
"gradient_accumulation_steps": 1,
|
| 57 |
+
"output_dir": output_dir,
|
| 58 |
"learning_rate": 0.00001,
|
| 59 |
"optimizer": "adamw_bnb_8bit",
|
| 60 |
"lr_scheduler": "cosine",
|
|
|
|
| 66 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 67 |
|
| 68 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 69 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
| 70 |
|
| 71 |
+
@with_temp_dir
|
| 72 |
+
def test_ft_packed(self, output_dir):
|
| 73 |
# pylint: disable=duplicate-code
|
| 74 |
cfg = DictDefault(
|
| 75 |
{
|
|
|
|
| 99 |
"num_epochs": 1,
|
| 100 |
"micro_batch_size": 1,
|
| 101 |
"gradient_accumulation_steps": 1,
|
| 102 |
+
"output_dir": output_dir,
|
| 103 |
"learning_rate": 0.00001,
|
| 104 |
"optimizer": "adamw_bnb_8bit",
|
| 105 |
"lr_scheduler": "cosine",
|
|
|
|
| 111 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 112 |
|
| 113 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 114 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
tests/utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
helper utils for tests
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
from functools import wraps
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def with_temp_dir(test_func):
|
| 11 |
+
@wraps(test_func)
|
| 12 |
+
def wrapper(*args, **kwargs):
|
| 13 |
+
# Create a temporary directory
|
| 14 |
+
temp_dir = tempfile.mkdtemp()
|
| 15 |
+
try:
|
| 16 |
+
# Pass the temporary directory to the test function
|
| 17 |
+
test_func(temp_dir, *args, **kwargs)
|
| 18 |
+
finally:
|
| 19 |
+
# Clean up the directory after the test
|
| 20 |
+
shutil.rmtree(temp_dir)
|
| 21 |
+
|
| 22 |
+
return wrapper
|