chore: Refactor inf_kwargs out
Changed file: scripts/finetune.py (+5 additions, −5 deletions)
|
@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
|
|
| 63 |
return instruction
|
| 64 |
|
| 65 |
|
| 66 |
-
def do_inference(cfg, model, tokenizer, prompter
|
| 67 |
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
| 68 |
|
| 69 |
for token, symbol in default_tokens.items():
|
|
@@ -257,13 +257,13 @@ def train(
|
|
| 257 |
|
| 258 |
if cfg.inference:
|
| 259 |
logging.info("calling do_inference function")
|
| 260 |
-
|
| 261 |
if "prompter" in kwargs:
|
| 262 |
if kwargs["prompter"] == "None":
|
| 263 |
-
|
| 264 |
else:
|
| 265 |
-
|
| 266 |
-
do_inference(cfg, model, tokenizer,
|
| 267 |
return
|
| 268 |
|
| 269 |
if "shard" in kwargs:
|
|
|
|
| 63 |
return instruction
|
| 64 |
|
| 65 |
|
| 66 |
+
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
| 67 |
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
| 68 |
|
| 69 |
for token, symbol in default_tokens.items():
|
|
|
|
| 257 |
|
| 258 |
if cfg.inference:
|
| 259 |
logging.info("calling do_inference function")
|
| 260 |
+
prompter: Optional[str] = "AlpacaPrompter"
|
| 261 |
if "prompter" in kwargs:
|
| 262 |
if kwargs["prompter"] == "None":
|
| 263 |
+
prompter = None
|
| 264 |
else:
|
| 265 |
+
prompter = kwargs["prompter"]
|
| 266 |
+
do_inference(cfg, model, tokenizer, prompter=prompter)
|
| 267 |
return
|
| 268 |
|
| 269 |
if "shard" in kwargs:
|