Spaces:
Running
Running
feat(data): support accumulation in non-streaming
Browse files
- src/dalle_mini/data.py +10 -2
src/dalle_mini/data.py
CHANGED
@@ -161,13 +161,16 @@ class Dataset:
|
|
161 |
def _dataloader_datasets_non_streaming(
|
162 |
dataset: Dataset,
|
163 |
per_device_batch_size: int,
|
|
|
164 |
rng: jax.random.PRNGKey = None,
|
165 |
):
|
166 |
"""
|
167 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
168 |
Shuffle batches if rng is set.
|
169 |
"""
|
170 |
-
batch_size = per_device_batch_size * num_devices
|
|
|
|
|
171 |
steps_per_epoch = len(dataset) // batch_size
|
172 |
|
173 |
if rng is not None:
|
@@ -183,6 +186,11 @@ class Dataset:
|
|
183 |
for idx in batch_idx:
|
184 |
batch = dataset[idx]
|
185 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
186 |
batch = shard(batch)
|
187 |
yield batch
|
188 |
|
@@ -244,7 +252,7 @@ class Dataset:
|
|
244 |
if split == "train":
|
245 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
246 |
return _dataloader_datasets_non_streaming(
|
247 |
-
ds, per_device_batch_size, input_rng
|
248 |
)
|
249 |
|
250 |
@property
|
|
|
161 |
def _dataloader_datasets_non_streaming(
|
162 |
dataset: Dataset,
|
163 |
per_device_batch_size: int,
|
164 |
+
gradient_accumulation_steps: int,
|
165 |
rng: jax.random.PRNGKey = None,
|
166 |
):
|
167 |
"""
|
168 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
169 |
Shuffle batches if rng is set.
|
170 |
"""
|
171 |
+
batch_size = (
|
172 |
+
per_device_batch_size * num_devices * gradient_accumulation_steps
|
173 |
+
)
|
174 |
steps_per_epoch = len(dataset) // batch_size
|
175 |
|
176 |
if rng is not None:
|
|
|
186 |
for idx in batch_idx:
|
187 |
batch = dataset[idx]
|
188 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
189 |
+
if gradient_accumulation_steps is not None:
|
190 |
+
batch = jax.tree_map(
|
191 |
+
lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
|
192 |
+
batch,
|
193 |
+
)
|
194 |
batch = shard(batch)
|
195 |
yield batch
|
196 |
|
|
|
252 |
if split == "train":
|
253 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
254 |
return _dataloader_datasets_non_streaming(
|
255 |
+
ds, per_device_batch_size, gradient_accumulation_steps, input_rng
|
256 |
)
|
257 |
|
258 |
@property
|