move is_llama_derived_model into normalize_config (#524)
- scripts/finetune.py +1 -10
- src/axolotl/utils/config.py +11 -0
scripts/finetune.py
CHANGED
@@ -24,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model_config, load_tokenizer
+from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -216,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
             else:
                 cfg[k] = kwargs[k]

-    model_config = load_model_config(cfg)
-
-    # figure out if the model is llama
-    cfg.is_llama_derived_model = (
-        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
-        or cfg.is_llama_derived_model
-        or "llama" in cfg.base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-    )
     validate_config(cfg)

     normalize_config(cfg)
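For context, here is a minimal sketch of what load_cfg in scripts/finetune.py looks like after this change: the llama detection is gone from here, and the function just merges CLI overrides and then calls validate_config and normalize_config. This is simplified and illustrative only; the real function also handles strict mode and boolean coercion of CLI values.

```python
# Simplified sketch of load_cfg after this commit (illustrative only;
# the real function also handles strict mode and boolean CLI overrides).
from pathlib import Path

import yaml

from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault


def load_cfg(config: Path = Path("examples/"), **kwargs):
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))

    # overwrite yaml values with any matching CLI kwargs
    for k, v in kwargs.items():
        if k in cfg:
            cfg[k] = v

    validate_config(cfg)
    # is_llama_derived_model is now resolved inside normalize_config
    normalize_config(cfg)
    return cfg
```

One practical upshot: any entry point that calls normalize_config, not only the finetune script, now gets cfg.is_llama_derived_model populated the same way.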
src/axolotl/utils/config.py
CHANGED
@@ -6,6 +6,7 @@ import os
 import torch

 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.models import load_model_config

 LOG = logging.getLogger("axolotl")

@@ -69,6 +70,16 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32

+    model_config = load_model_config(cfg)
+
+    # figure out if the model is llama
+    cfg.is_llama_derived_model = (
+        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        or cfg.is_llama_derived_model
+        or "llama" in cfg.base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

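The detection logic itself is unchanged by this commit, only relocated into normalize_config. As a standalone illustration of the moved predicate, the sketch below pulls the boolean expression out into a helper; FakeModelConfig is a hypothetical stand-in for the transformers config object that load_model_config(cfg) returns, and is_llama_derived does not exist in the repo, it is only for demonstration.

```python
# Illustrative sketch of the relocated check; FakeModelConfig stands in
# for the config object that load_model_config(cfg) would return.
from dataclasses import dataclass

from axolotl.utils.dict import DictDefault


@dataclass
class FakeModelConfig:
    model_type: str


def is_llama_derived(cfg: DictDefault, model_config) -> bool:
    # same boolean expression that now lives in normalize_config
    return bool(
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
        or cfg.is_llama_derived_model
        or "llama" in cfg.base_model
        or (cfg.model_type and "llama" in cfg.model_type.lower())
    )


cfg = DictDefault({"base_model": "NousResearch/Llama-2-7b-hf", "model_type": "LlamaForCausalLM"})
print(is_llama_derived(cfg, FakeModelConfig(model_type="llama")))  # True

cfg = DictDefault({"base_model": "tiiuae/falcon-7b", "model_type": "AutoModelForCausalLM"})
print(is_llama_derived(cfg, FakeModelConfig(model_type="falcon")))  # False
```

Note that the expression short-circuits, so a config whose model_config already reports model_type == "llama" never needs the name-based fallbacks on base_model or cfg.model_type.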