|
import json

import torch

|
def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps,
                      resume_step, seed):
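    """Pre-compute this rank's full stream of sample indices for a training run.

    The same permutation sequence is replayed deterministically from ``seed``
    on every rank, epoch by epoch, and sharded with a stride of ``world_size``,
    so no inter-process communication is needed. Slicing at ``resume_step``
    lets an interrupted run continue with exactly the samples it had not yet
    consumed.
    """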
|
    # One index per sample this rank consumes over the whole run: each of
    # max_steps batches contributes global_batch_size // world_size samples.
    sample_indices = torch.empty([max_steps * global_batch_size // world_size],
                                 dtype=torch.long)
    epoch_id, fill_ptr, offs = 0, 0, 0
    while fill_ptr < sample_indices.size(0):
        # Rebuild each epoch's permutation deterministically from the seed so
        # every rank (and every resumed run) sees the same global order.
        g = torch.Generator()
        g.manual_seed(seed + epoch_id)
        epoch_sample_indices = torch.randperm(len(dataset), generator=g)
        epoch_id += 1
        # Shard the permutation across ranks with a stride of world_size.
        epoch_sample_indices = epoch_sample_indices[
            (rank + offs) % world_size::world_size
        ]
        # Advance the offset so the stride pattern continues seamlessly across
        # the epoch boundary: when len(dataset) is not divisible by world_size,
        # this also rotates which ranks receive the one extra sample.
        offs = (offs + world_size - len(dataset) % world_size) % world_size
        # Truncate the final epoch so the buffer is filled exactly.
        epoch_sample_indices = epoch_sample_indices[
            :sample_indices.size(0) - fill_ptr
        ]
        sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \
            epoch_sample_indices
        fill_ptr += epoch_sample_indices.size(0)
    # Skip the samples already consumed before the resume point.
    return sample_indices[resume_step * global_batch_size // world_size:].tolist()
|
|
def get_packed_batch_sampler(
    dataset, rank, world_size, max_steps, resume_step, seed
):
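    """Pre-compute this rank's stream of per-step items for a resumable run.

    Mirrors ``get_train_sampler``, except that one dataset item is consumed
    per step (each item presumably being an already-packed batch) and the
    items themselves, rather than their indices, are returned.
    """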
|
    # One dataset item (presumably a pre-packed batch) is consumed per step.
    sample_indices = [None for _ in range(max_steps)]
    epoch_id, fill_ptr, offs = 0, 0, 0
    while fill_ptr < len(sample_indices):
        # Rebuild each epoch's permutation deterministically from the seed so
        # every rank (and every resumed run) sees the same global order.
        g = torch.Generator()
        g.manual_seed(seed + epoch_id)
        epoch_sample_indices = torch.randperm(len(dataset), generator=g)
        epoch_id += 1
        # Shard the permutation across ranks with a stride of world_size.
        epoch_sample_indices = epoch_sample_indices[
            (rank + offs) % world_size::world_size
        ]
        # Advance the offset so the stride pattern continues seamlessly across
        # the epoch boundary when len(dataset) is not divisible by world_size.
        offs = (offs + world_size - len(dataset) % world_size) % world_size
        # Truncate the final epoch so exactly max_steps items are collected.
        epoch_sample_indices = epoch_sample_indices[
            :len(sample_indices) - fill_ptr
        ]
        # Materialize the items; .tolist() yields plain ints, which any
        # __getitem__ implementation accepts.
        sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = [
            dataset[i] for i in epoch_sample_indices.tolist()
        ]
        fill_ptr += epoch_sample_indices.size(0)
    # Drop the items already consumed before the resume point.
    return sample_indices[resume_step:]
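

# A minimal, self-contained usage sketch (illustrative only: the toy dataset
# and the batch/step numbers below are assumptions, not part of the module).
# It shows that each rank derives its sample stream locally and that resuming
# simply drops the already-consumed prefix.
if __name__ == "__main__":
    toy_dataset = list(range(10))

    # 3 steps * (4 // 2) = 6 per-rank samples; resume_step=1 drops the first 2.
    indices = get_train_sampler(
        toy_dataset, rank=0, world_size=2, global_batch_size=4,
        max_steps=3, resume_step=1, seed=0,
    )
    print(indices)  # 4 remaining sample indices for rank 0

    # One item per step; resume_step=1 drops the first of the 3 items.
    batches = get_packed_batch_sampler(
        toy_dataset, rank=0, world_size=2, max_steps=3, resume_step=1, seed=0,
    )
    print(batches)  # 2 remaining per-step items for rank 0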
|
|
|
|