rayst3r / utils /collate.py
bartduis's picture
init
70d1188
raw
history blame
213 Bytes
import torch
def collate(batch):
if isinstance(batch[0],dict):
return {k: collate([d[k] for d in batch]) for k in batch[0].keys()}
else:
return torch.stack([torch.stack(t) for t in batch])