fix(train): handle seed_dataset
- src/dalle_mini/data.py +3 -8
- tools/train/train.py +2 -2
src/dalle_mini/data.py
CHANGED

@@ -161,7 +161,7 @@ class Dataset:
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-            Shuffle batches if
+            Shuffle batches if rng is set.
             """
             steps_per_epoch = len(dataset) // batch_size
 
@@ -184,17 +184,13 @@ class Dataset:
         def _dataloader_datasets_streaming(
             dataset: Dataset, batch_size: int, epoch: int
         ):
-            # epoch is only use for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
             first_loop = True
             while self.multi_hosts or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
                 # at the same time and we don't know how much data is on each host
-
-                # multi-host setting, we reshuffle shards
-                epoch += 1
-                dataset.set_epoch(epoch)
+                dataset.set_epoch(epoch)  # reshuffle data at each epoch
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
@@ -203,6 +199,7 @@ class Dataset:
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
+                epoch += 1
                 first_loop = False
 
         if split == "train":
@@ -213,8 +210,6 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            if split == "train":
-                ds.set_epoch(epoch)
             return _dataloader_datasets_streaming(ds, batch_size, epoch)
         else:
             if split == "train":
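For readers skimming the diff, the net effect on the streaming path is that every pass over the stream now begins with dataset.set_epoch(epoch), and the epoch counter is only incremented after a full pass completes. A minimal standalone sketch of that loop shape follows; ToyStreamingDataset and streaming_batches are illustrative stand-ins, not the repository's code.

    # Sketch only: mirrors the reshuffle-then-iterate-then-increment order
    # introduced by this commit, with a toy class instead of datasets.IterableDataset.
    class ToyStreamingDataset:
        def __init__(self, items):
            self.items = list(items)
            self.epoch = 0

        def set_epoch(self, epoch):
            # a real streaming dataset would reseed its shard shuffling here
            self.epoch = epoch

        def __iter__(self):
            return iter(self.items)


    def streaming_batches(dataset, batch_size, epoch, multi_hosts=False):
        batch = []
        first_loop = True
        while multi_hosts or first_loop:
            dataset.set_epoch(epoch)  # reshuffle data at each epoch
            for item in dataset:
                batch.append(item)
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            epoch += 1  # a completed pass over the stream counts as one epoch
            first_loop = False


    for b in streaming_batches(ToyStreamingDataset(range(6)), batch_size=2, epoch=0):
        print(b)

In the single-host case the while loop runs once, exactly as before; in the multi-host case it keeps looping, reshuffling with a fresh epoch value on each pass.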
tools/train/train.py
CHANGED

@@ -241,7 +241,7 @@ class TrainingArguments:
     )
     optim_quantized: bool = field(
         default=False,
-
+        metadata={
             "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
         },
     )

@@ -845,7 +845,7 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size)
+        train_loader = dataset.dataloader("train", train_batch_size, epoch)
         # train
         for batch in tqdm(
             train_loader,
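On the caller side, the only substantive change is that the epoch index is now threaded into dataset.dataloader so the streaming branch above can reshuffle per epoch. A hedged sketch of that call pattern, with a toy dataset and loop standing in for the real training loop:

    # Sketch only: the dataloader("train", batch_size, epoch) signature follows the
    # diff; ToyDataset and the two-epoch loop stand in for the real training loop.
    class ToyDataset:
        def __init__(self, n):
            self.data = list(range(n))

        def dataloader(self, split, batch_size, epoch):
            # the real Dataset.dataloader dispatches to streaming / non-streaming
            # helpers; epoch is forwarded so streaming data can call set_epoch(epoch)
            for i in range(0, len(self.data), batch_size):
                yield self.data[i : i + batch_size]


    dataset = ToyDataset(8)
    train_batch_size = 4
    for epoch in range(2):
        train_loader = dataset.dataloader("train", train_batch_size, epoch)
        for batch in train_loader:
            print(epoch, batch)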