Angainor Development committed
Feed cfg.inference

scripts/finetune.py CHANGED (+7 -5)
@@ -182,6 +182,9 @@ def train(
     if cfg.bf16:
         cfg.fp16 = True
         cfg.bf16 = False
+
+    # Store inference mode into cfg when passed via args
+    cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
-    ): # don't need to load dataset for these
+        ["shard", "merge_lora"], kwargs
+    ) and not cfg.inference: # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -216,8 +219,7 @@ def train(
         cfg.model_type,
         tokenizer,
         cfg,
-        adapter=cfg.adapter,
-        inference=("inference" in kwargs),
+        adapter=cfg.adapter
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
         return
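
Net effect: inference mode is now resolved once, up front, into cfg.inference. Presence of the "inference" kwarg switches it on regardless of the kwarg's value; otherwise any "inference" key already present in the config is honored, defaulting to False. A standalone sketch of that precedence, using a hypothetical resolve_inference helper and plain dicts in place of the attrdict-style cfg (the commit inlines this logic directly in train()):

# Minimal sketch of the flag resolution introduced above; resolve_inference
# is a hypothetical helper, not part of the commit itself.
def resolve_inference(kwargs: dict, cfg: dict) -> bool:
    # Presence of the "inference" kwarg forces inference mode on, even if
    # its value is falsy; otherwise defer to the config, defaulting to False.
    return True if "inference" in kwargs else cfg.get("inference", False)

assert resolve_inference({"inference": True}, {}) is True   # kwarg wins
assert resolve_inference({}, {"inference": True}) is True   # config fallback
assert resolve_inference({}, {}) is False                   # default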
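The dataset guard keeps its old behavior under the new flag. check_not_in is not defined in this diff; assuming it returns True when none of the listed keys appear in kwargs, the guard skips dataset loading in the same cases before and after the change, with "inference" simply moved from the key list into the cfg flag. A hedged sketch under that assumption:

# Assumed semantics of check_not_in: True when none of the keys are in kwargs
# (its definition is not shown in this diff). should_load_datasets is a
# hypothetical wrapper around the post-commit condition.
def check_not_in(keys, kwargs) -> bool:
    return all(key not in kwargs for key in keys)

def should_load_datasets(kwargs: dict, inference: bool) -> bool:
    # After this commit: load datasets only when not sharding, not merging
    # a LoRA, and not running inference.
    return check_not_in(["shard", "merge_lora"], kwargs) and not inference

assert should_load_datasets({}, inference=False) is True
assert should_load_datasets({"shard": True}, inference=False) is False
assert should_load_datasets({}, inference=True) is False    # same skip as before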