File size: 1,336 Bytes
c8ddb9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
"""Custom collate function for the data loader."""
from typing import Any, List
import torch
from torch.nn.utils.rnn import pad_sequence
def custom_collate(batch: List[Any], device: Any) -> Any:
    """
    Custom collate function to be used in the data loader.

    Each sample in ``batch`` is expected to provide exactly four items,
    ``(img, correct_capt, curr_class, word_labels)``. Image and class tensors
    are stacked along a new leading batch dimension. Caption and word-label
    sequences are right-padded with 0 to the longest sequence in the batch.

    :param batch: list of samples, with length equal to the batch size
        (number of samples), not the number of batches.
    :param device: target device for the batched tensors; anything accepted
        by ``Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``, a ``torch.device``).
    :return: tuple ``(batched_img, batched_correct_capt, correct_capt_len,
        batched_curr_class, batched_word_labels)``. ``correct_capt_len`` is an
        int64 tensor of shape (batch_size, 1) holding the unpadded caption
        lengths; unlike the other outputs it is NOT moved to ``device``.
    :raises ValueError: if ``batch`` is empty.
    """
    if not batch:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError("custom_collate received an empty batch")

    img, correct_capt, curr_class, word_labels = zip(*batch)

    # All images must share one shape for torch.stack.
    # Assumes (3, height, width) per image -- TODO confirm with the dataset.
    batched_img = torch.stack(img, dim=0).to(device)  # (batch_size, 3, height, width)

    # Unpadded lengths (len() of a tensor is its first dimension).
    # Intentionally left on CPU; presumably consumed by
    # pack_padded_sequence, which requires CPU lengths -- confirm with caller.
    correct_capt_len = torch.tensor(
        [len(capt) for capt in correct_capt], dtype=torch.int64
    ).unsqueeze(1)  # (batch_size, 1)

    batched_correct_capt = pad_sequence(
        correct_capt, batch_first=True, padding_value=0
    ).to(device)  # (batch_size, max_seq_len)

    # Assumes each curr_class has shape (1,), giving (batch_size, 1) -- TODO confirm.
    batched_curr_class = torch.stack(curr_class, dim=0).to(device)

    batched_word_labels = pad_sequence(
        word_labels, batch_first=True, padding_value=0
    ).to(device)  # (batch_size, max_seq_len)

    return (
        batched_img,
        batched_correct_capt,
        correct_capt_len,
        batched_curr_class,
        batched_word_labels,
    )
|