Merge pull request #183 from OpenAccess-AI-Collective/inference-from-stdin
- README.md +5 -0
- scripts/finetune.py +18 -5
README.md CHANGED

````diff
@@ -495,6 +495,11 @@ Pass the appropriate flag to the train command:
 ```bash
 --inference --base_model ./completed-model
 ```
+- Full weights finetune w/ a prompt from a text file:
+```bash
+cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
+    --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
+```
 
 ### Merge LORA to base
 
````
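The README example relies on the two behaviors this PR adds: the prompt is piped in on stdin instead of typed interactively, and `--prompter=None` skips prompt templating so the piped text reaches the model verbatim. A minimal sketch of that pass-through decision, assuming a hypothetical free function `build_prompt` (the real logic lives inline in `do_inference`):

```python
# Hedged sketch of the prompter/no-prompter branch; not axolotl's API.
import sys

def build_prompt(instruction: str, prompter_cls=None) -> str:
    if prompter_cls is not None:
        # Templated path: the prompter wraps the instruction in its template.
        return next(prompter_cls().build_prompt(instruction=instruction.strip("\n")))
    # --prompter=None path: use the piped text as-is.
    return instruction.strip()

if __name__ == "__main__":
    # Usage: cat /tmp/prompt.txt | python this_sketch.py
    print(build_prompt(sys.stdin.read()))
```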
scripts/finetune.py CHANGED

````diff
@@ -71,7 +71,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     if not (cfg.special_tokens and token in cfg.special_tokens):
         tokenizer.add_special_tokens({token: symbol})
 
-    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
 
     while True:
         print("=" * 80)
````
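This hunk makes the prompter lookup conditional: a falsy `prompter` disables templating, while a string is resolved to a class on `axolotl.prompters` via `importlib`. A hedged sketch of the same lookup pattern with explicit error handling (`resolve_prompter` is an illustrative name, not part of the codebase):

```python
import importlib

def resolve_prompter(name):
    """Resolve a prompter class by name; None/empty means raw prompts."""
    if not name:
        return None
    module = importlib.import_module("axolotl.prompters")
    try:
        return getattr(module, name)
    except AttributeError as err:
        raise ValueError(f"unknown prompter {name!r}") from err

# resolve_prompter("AlpacaPrompter") -> the class; resolve_prompter(None) -> None
```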
````diff
@@ -79,9 +83,12 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(
-            prompter_module().build_prompt(instruction=instruction.strip("\n"))
-        )
+        if prompter_module:
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
         print("=" * 40)
         model.eval()
````
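The `next(...)` call works because a prompter's `build_prompt` is a generator yielding the formatted string; the branch keeps that templated path and adds a raw path for piped prompts. A toy prompter following the same protocol, inferred from the call site (`ToyPrompter` is not a real axolotl class):

```python
class ToyPrompter:
    """Illustrative prompter: build_prompt is a generator yielding one string."""
    def build_prompt(self, instruction: str, **kwargs):
        yield f"### Instruction:\n{instruction}\n\n### Response:\n"

# Templated path, as in the hunk above:
prompt: str = next(ToyPrompter().build_prompt(instruction="Hi there\n".strip("\n")))
assert prompt.startswith("### Instruction:\nHi there")

# Raw path added by this PR (--prompter=None): no template at all.
assert "Hi there\n".strip() == "Hi there"
```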
````diff
@@ -242,7 +249,13 @@ def train(
 
     if "inference" in kwargs:
         logging.info("calling do_inference function")
-        do_inference(cfg, model, tokenizer)
+        inf_kwargs: Dict[str, Any] = {}
+        if "prompter" in kwargs:
+            if kwargs["prompter"] == "None":
+                inf_kwargs["prompter"] = None
+            else:
+                inf_kwargs["prompter"] = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, **inf_kwargs)
         return
 
     if "shard" in kwargs:
````
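The train-side hunk exists because the value arrives from the command line as the literal string `"None"` rather than a real `None`; it normalizes the string back to `None` before forwarding to `do_inference`. A small self-contained sketch of the same normalization (the helper name is hypothetical):

```python
from typing import Any, Dict

def normalize_inference_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Map the CLI string "None" to a real None; pass other prompters through."""
    inf_kwargs: Dict[str, Any] = {}
    if "prompter" in kwargs:
        value = kwargs["prompter"]
        inf_kwargs["prompter"] = None if value == "None" else value
    return inf_kwargs

assert normalize_inference_kwargs({"prompter": "None"}) == {"prompter": None}
assert normalize_inference_kwargs({"prompter": "AlpacaPrompter"}) == {"prompter": "AlpacaPrompter"}
assert normalize_inference_kwargs({}) == {}
```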