Merge pull request #159 from AngainorDev/patch-1
- scripts/finetune.py +4 -5
- src/axolotl/utils/models.py +4 -9
scripts/finetune.py
CHANGED
@@ -165,7 +165,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -205,8 +205,8 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)

-    if
-        ["
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ): # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
@@ -234,7 +234,6 @@ def train(
         tokenizer,
         cfg,
         adapter=cfg.adapter,
-        inference=("inference" in kwargs),
     )

     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return

-    if
+    if cfg.inference:
         logging.info("calling do_inference function")
         inf_kwargs: Dict[str, Any] = {}
         if "prompter" in kwargs:
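For context, the rewritten condition above leans on two things: a small `check_not_in` helper in scripts/finetune.py that returns True when none of the listed keys were passed on the command line, and `cfg.inference` being populated by the kwargs-to-cfg loop in the first hunk, so an `--inference` flag lands on `cfg` instead of being probed via `kwargs` everywhere. A minimal sketch of that helper and the call site (this implementation is a plausible reading of the call, not necessarily the repo's exact code):

    from typing import Any, Dict, Iterable

    def check_not_in(keys: Iterable[str], kwargs: Dict[str, Any]) -> bool:
        # True when none of the listed CLI kwargs were supplied
        return not any(key in kwargs for key in keys)

    # e.g. `python scripts/finetune.py config.yml --inference` -> kwargs = {"inference": True}
    kwargs = {"inference": True}
    check_not_in(["shard", "merge_lora"], kwargs)  # True, but cfg.inference still skips dataset loading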
src/axolotl/utils/models.py
CHANGED
@@ -77,15 +77,9 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer,
-    cfg,
-    adapter="lora",
-    inference=False,
+    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -98,7 +92,7 @@ def load_model(
         )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -439,6 +433,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
+            is_trainable=not cfg.inference,
             device_map=cfg.device_map,
             # torch_dtype=torch.float16,
         )
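The added `is_trainable` argument matters because `PeftModel.from_pretrained` loads an adapter in inference mode by default, leaving the LoRA weights frozen; passing `is_trainable=True` keeps them trainable, so resuming fine-tuning from `lora_model_dir` actually updates the adapter. A minimal sketch of the distinction (model name and adapter path are placeholders, not values from this PR):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("base-model")  # placeholder model id

    # Inference: PEFT's default (is_trainable=False) freezes the adapter weights.
    inference_model = PeftModel.from_pretrained(base, "path/to/lora")

    # Continued training: keep the adapter parameters trainable, which is what
    # the patch requests whenever cfg.inference is not set.
    train_model = PeftModel.from_pretrained(base, "path/to/lora", is_trainable=True)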