fix dataset handling, support galactica
Browse files- configs/galactica_1_3B.yml +41 -0
- src/axolotl/utils/data.py +14 -11
- src/axolotl/utils/models.py +4 -0
configs/galactica_1_3B.yml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
base_model: facebook/galactica-1.3b
|
| 2 |
+
model_type: AutoModelForCausalLM
|
| 3 |
+
tokenizer_type: AutoTokenizer
|
| 4 |
+
load_in_8bit: false
|
| 5 |
+
datasets:
|
| 6 |
+
- path: tatsu-lab/alpaca
|
| 7 |
+
type: alpaca
|
| 8 |
+
dataset_prepared_path: last_run_prepared
|
| 9 |
+
val_set_size: 0.1
|
| 10 |
+
adapter:
|
| 11 |
+
lora_model_dir:
|
| 12 |
+
sequence_len: 1024
|
| 13 |
+
max_packed_sequence_len: 1024
|
| 14 |
+
lora_r: 8
|
| 15 |
+
lora_alpha: 16
|
| 16 |
+
lora_dropout: 0.05
|
| 17 |
+
lora_target_modules:
|
| 18 |
+
- q_proj
|
| 19 |
+
- v_proj
|
| 20 |
+
lora_fan_in_fan_out: false
|
| 21 |
+
wandb_project:
|
| 22 |
+
wandb_watch:
|
| 23 |
+
wandb_run_id:
|
| 24 |
+
wandb_log_model: checkpoint
|
| 25 |
+
output_dir: ./lora-llama-alpaca
|
| 26 |
+
batch_size: 32
|
| 27 |
+
micro_batch_size: 16
|
| 28 |
+
num_epochs: 3
|
| 29 |
+
learning_rate: 0.00003
|
| 30 |
+
train_on_inputs: false
|
| 31 |
+
group_by_length: false
|
| 32 |
+
bf16: false
|
| 33 |
+
tf32: false
|
| 34 |
+
early_stopping_patience:
|
| 35 |
+
resume_from_checkpoint:
|
| 36 |
+
local_rank:
|
| 37 |
+
special_tokens:
|
| 38 |
+
pad_token: "[PAD]"
|
| 39 |
+
bos_token: "<s>"
|
| 40 |
+
eos_token: "</s>"
|
| 41 |
+
unk_token: "<unk>"
|
src/axolotl/utils/data.py
CHANGED
|
@@ -31,7 +31,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
| 31 |
ds_hash = str(
|
| 32 |
md5(
|
| 33 |
(
|
| 34 |
-
str(
|
| 35 |
+ "@"
|
| 36 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
| 37 |
).encode("utf-8")
|
|
@@ -114,21 +114,24 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
| 114 |
datasets.append(ds_wrapper)
|
| 115 |
else:
|
| 116 |
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
| 117 |
-
logging.info("merging and shuffling master dataset")
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
| 120 |
if cfg.local_rank == 0:
|
| 121 |
logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
| 122 |
dataset.save_to_disk(prepared_ds_path)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
|
| 133 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
| 134 |
logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
|
|
|
|
| 31 |
ds_hash = str(
|
| 32 |
md5(
|
| 33 |
(
|
| 34 |
+
str(cfg.sequence_len)
|
| 35 |
+ "@"
|
| 36 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
| 37 |
).encode("utf-8")
|
|
|
|
| 114 |
datasets.append(ds_wrapper)
|
| 115 |
else:
|
| 116 |
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
| 117 |
+
logging.info("tokenizing, merging, and shuffling master dataset")
|
| 118 |
|
| 119 |
+
samples = []
|
| 120 |
+
for d in datasets:
|
| 121 |
+
samples = samples + [i for i in d]
|
| 122 |
+
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
| 123 |
if cfg.local_rank == 0:
|
| 124 |
logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
| 125 |
dataset.save_to_disk(prepared_ds_path)
|
| 126 |
|
| 127 |
+
if cfg.max_packed_sequence_len is not None:
|
| 128 |
+
constant_len_dataset = ConstantLengthDataset(
|
| 129 |
+
tokenizer,
|
| 130 |
+
[dataset],
|
| 131 |
+
seq_length=max_packed_sequence_len,
|
| 132 |
+
)
|
| 133 |
+
logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
|
| 134 |
+
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
| 135 |
|
| 136 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
| 137 |
logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
|
src/axolotl/utils/models.py
CHANGED
|
@@ -161,6 +161,10 @@ def load_model(
|
|
| 161 |
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
| 162 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
if load_in_8bit and not cfg.load_4bit:
|
| 165 |
logging.info("converting model w/ prepare_model_for_int8_training")
|
| 166 |
model = prepare_model_for_int8_training(model)
|
|
|
|
| 161 |
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
| 162 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 163 |
|
| 164 |
+
if cfg.special_tokens:
|
| 165 |
+
for k, v in cfg.special_tokens.items():
|
| 166 |
+
setattr(tokenizer, k, v)
|
| 167 |
+
|
| 168 |
if load_in_8bit and not cfg.load_4bit:
|
| 169 |
logging.info("converting model w/ prepare_model_for_int8_training")
|
| 170 |
model = prepare_model_for_int8_training(model)
|