Fix data.py lint
Browse files
- src/axolotl/utils/data.py +18 -15
src/axolotl/utils/data.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
from hashlib import md5
|
| 3 |
from pathlib import Path
|
|
@@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets(
|
|
| 46 |
md5(
|
| 47 |
(
|
| 48 |
str(cfg.sequence_len)
|
| 49 |
-
+ "@"
|
| 50 |
-
+ "|".join(
|
| 51 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
| 52 |
)
|
| 53 |
-
+ "|"
|
| 54 |
-
+ tokenizer_name
|
| 55 |
).encode("utf-8")
|
| 56 |
).hexdigest()
|
| 57 |
)
|
|
@@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets(
|
|
| 81 |
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
| 82 |
logging.info("Loading raw datasets...")
|
| 83 |
datasets = []
|
|
|
|
| 84 |
for d in cfg.datasets:
|
| 85 |
ds: Union[Dataset, DatasetDict] = None
|
| 86 |
ds_from_hub = False
|
|
@@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets(
|
|
| 229 |
|
| 230 |
samples = []
|
| 231 |
for d in datasets:
|
| 232 |
-
samples = samples +
|
| 233 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
| 234 |
if cfg.local_rank == 0:
|
| 235 |
logging.info(
|
|
@@ -265,14 +268,14 @@ def load_prepare_datasets(
|
|
| 265 |
md5(
|
| 266 |
(
|
| 267 |
str(cfg.sequence_len)
|
| 268 |
-
+ "@"
|
| 269 |
-
+ str(max_packed_sequence_len)
|
| 270 |
-
+ seed
|
| 271 |
-
+ "|".join(
|
| 272 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
| 273 |
)
|
| 274 |
-
+ "|"
|
| 275 |
-
+ tokenizer_name
|
| 276 |
).encode("utf-8")
|
| 277 |
).hexdigest()
|
| 278 |
)
|
|
@@ -327,7 +330,7 @@ def load_prepare_datasets(
|
|
| 327 |
logging.info(
|
| 328 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
| 329 |
)
|
| 330 |
-
dataset = Dataset.from_list(
|
| 331 |
|
| 332 |
# filter out bad data
|
| 333 |
dataset = Dataset.from_list(
|
|
@@ -335,9 +338,9 @@ def load_prepare_datasets(
|
|
| 335 |
d
|
| 336 |
for d in dataset
|
| 337 |
if len(d["input_ids"]) < cfg.sequence_len
|
| 338 |
-
and len(d["input_ids"]) > 0
|
| 339 |
-
and len(d["input_ids"]) == len(d["attention_mask"])
|
| 340 |
-
and len(d["input_ids"]) == len(d["labels"])
|
| 341 |
]
|
| 342 |
)
|
| 343 |
|
|
|
|
| 1 |
+
"""Module containing data utilities for Axolotl"""
|
| 2 |
+
|
| 3 |
import logging
|
| 4 |
from hashlib import md5
|
| 5 |
from pathlib import Path
|
|
|
|
| 48 |
md5(
|
| 49 |
(
|
| 50 |
str(cfg.sequence_len)
|
| 51 |
+
+ "@"
|
| 52 |
+
+ "|".join(
|
| 53 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
| 54 |
)
|
| 55 |
+
+ "|"
|
| 56 |
+
+ tokenizer_name
|
| 57 |
).encode("utf-8")
|
| 58 |
).hexdigest()
|
| 59 |
)
|
|
|
|
| 83 |
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
| 84 |
logging.info("Loading raw datasets...")
|
| 85 |
datasets = []
|
| 86 |
+
# pylint: disable=invalid-name
|
| 87 |
for d in cfg.datasets:
|
| 88 |
ds: Union[Dataset, DatasetDict] = None
|
| 89 |
ds_from_hub = False
|
|
|
|
| 232 |
|
| 233 |
samples = []
|
| 234 |
for d in datasets:
|
| 235 |
+
samples = samples + list(d)
|
| 236 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
| 237 |
if cfg.local_rank == 0:
|
| 238 |
logging.info(
|
|
|
|
| 268 |
md5(
|
| 269 |
(
|
| 270 |
str(cfg.sequence_len)
|
| 271 |
+
+ "@"
|
| 272 |
+
+ str(max_packed_sequence_len)
|
| 273 |
+
+ seed
|
| 274 |
+
+ "|".join(
|
| 275 |
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
| 276 |
)
|
| 277 |
+
+ "|"
|
| 278 |
+
+ tokenizer_name
|
| 279 |
).encode("utf-8")
|
| 280 |
).hexdigest()
|
| 281 |
)
|
|
|
|
| 330 |
logging.info(
|
| 331 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
| 332 |
)
|
| 333 |
+
dataset = Dataset.from_list(list(constant_len_dataset))
|
| 334 |
|
| 335 |
# filter out bad data
|
| 336 |
dataset = Dataset.from_list(
|
|
|
|
| 338 |
d
|
| 339 |
for d in dataset
|
| 340 |
if len(d["input_ids"]) < cfg.sequence_len
|
| 341 |
+
and len(d["input_ids"]) > 0
|
| 342 |
+
and len(d["input_ids"]) == len(d["attention_mask"])
|
| 343 |
+
and len(d["input_ids"]) == len(d["labels"])
|
| 344 |
]
|
| 345 |
)
|
| 346 |
|