update readme and add typehints
- README.md +1 -7
- src/axolotl/utils/data.py +8 -7
README.md CHANGED
@@ -363,13 +363,7 @@ Pass the appropriate flag to the train command:
 
 ### Merge LORA to base
 
-Add below flag to train command above
-
-```bash
---merge_lora --lora_model_dir="./completed-model"
-```
-
-Add below flag to train command above (and using QLoRA)
+Add below flag to train command above
 
 ```bash
 --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
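For context: the surviving instruction merges a trained LoRA adapter back into its base model, and merging is typically done against unquantized weights, which is presumably why the retained flags force `--load_in_8bit=False --load_in_4bit=False`. A minimal sketch of what such a merge amounts to, assuming a PEFT-style adapter; the model paths are placeholders, and axolotl's own merge path may differ:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model in half precision -- not 8-bit/4-bit -- since the
# adapter weights get folded into the base weights during the merge.
base = AutoModelForCausalLM.from_pretrained(
    "base-model",  # placeholder: the model the adapter was trained on
    torch_dtype=torch.float16,
)

# Attach the trained adapter, fold it in, and save a standalone model.
model = PeftModel.from_pretrained(base, "./completed-model")
merged = model.merge_and_unload()
merged.save_pretrained("./merged-model")  # placeholder output directory
```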
src/axolotl/utils/data.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 from hashlib import md5
 from pathlib import Path
+from typing import Union
 
 from datasets import (
     load_from_disk,
@@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets(
     logging.info("Loading raw datasets...")
     datasets = []
     for d in cfg.datasets:
-        ds = None
+        ds: Union[Dataset, DatasetDict] = None
         ds_from_hub = False
         try:
             load_dataset(d.path, streaming=True, use_auth_token=True)
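The new `Union[Dataset, DatasetDict]` hint reflects that `datasets.load_dataset` returns either type depending on the `split` argument; a quick self-contained illustration:

```python
from pathlib import Path

from datasets import Dataset, DatasetDict, load_dataset

# Write a tiny JSONL file so the example runs standalone.
Path("data.jsonl").write_text('{"text": "hello"}\n{"text": "world"}\n')

# With split=None, load_dataset returns a DatasetDict keyed by split name...
dsd = load_dataset("json", data_files="data.jsonl", split=None)
assert isinstance(dsd, DatasetDict)

# ...while selecting a split yields a plain Dataset -- hence the Union hint.
train = dsd["train"]
assert isinstance(train, Dataset)
```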
@@ -90,32 +91,32 @@
 
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
-            ds = load_dataset(
+            ds: Dataset = load_dataset(
                 "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(
+                ds: Dataset = load_dataset(
                     d.path,
                     streaming=False,
                     data_files=d.data_files,
                     use_auth_token=True,
                 )
             else:
-                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
+                ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
+            ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
             if "train" in ds:
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
             else:
-                ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+                ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
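For reference, `Dataset.shard(num_shards=n, index=0)` keeps roughly a `1/n` slice of the shuffled data, so `d.shards` acts as a subsampling factor. Note that `.shard()` itself returns a `Dataset`, so the `DatasetDict` annotation on the `["train"]` branch above is looser than the runtime type. A toy example:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

# Deterministic shuffle, then keep the first of four shards (~25 rows).
subset = ds.shuffle(seed=42).shard(num_shards=4, index=0)
print(len(subset))  # 25
```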