File size: 213 Bytes
70d1188
 
 
 
 
 
 
1
2
3
4
5
6
7
from collections.abc import Mapping

import torch

def collate(batch):
    """Recursively collate a list of samples into batched tensors.

    Each sample is a mapping, a tensor, or a sequence of tensors:

    * Mappings are collated key by key, using the keys of the first sample;
      the result is a plain ``dict``. Nested mappings recurse.
    * Tensors are stacked along a new leading dimension, giving
      ``(len(batch), *tensor.shape)``.
    * Sequences of tensors (e.g. lists or tuples) are first stacked within
      each sample and then across the batch, giving
      ``(len(batch), len(sequence), *tensor.shape)``.

    Args:
        batch: Non-empty list of samples, all sharing the same structure.

    Returns:
        A tensor, or a (possibly nested) dict of tensors mirroring the
        structure of ``batch[0]``.

    Raises:
        IndexError: If ``batch`` is empty.
        KeyError: If a mapping sample lacks a key present in ``batch[0]``.
        RuntimeError: If tensors being stacked differ in shape (raised by
            ``torch.stack``).
    """
    first = batch[0]
    if isinstance(first, Mapping):
        # Keys are taken from the first sample; every other sample must have them.
        return {key: collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, torch.Tensor):
        # torch.stack() rejects a bare Tensor as its sequence argument, so
        # tensor leaves are stacked across the batch directly.
        return torch.stack(batch)
    # Leaves are sequences of tensors: stack within each sample, then across.
    return torch.stack([torch.stack(seq) for seq in batch])