"""Custom collate function for the data loader."""

from typing import Any, List

import torch
from torch.nn.utils.rnn import pad_sequence


def custom_collate(batch: List[Any], device: Any) -> Any:
    """
    Custom collate function to be used in the data loader.
    :param batch: list, with length equal to number of batches.
    :return: processed batch of data [add padding to text, stack tensors in batch]
    """
    img, correct_capt, curr_class, word_labels = zip(*batch)

    # Stack images along the batch dimension: (batch_size, 3, height, width).
    batched_img = torch.stack(img, dim=0).to(device)

    # Unpadded caption lengths: (batch_size, 1).
    correct_capt_len = torch.tensor(
        [len(capt) for capt in correct_capt], dtype=torch.int64
    ).unsqueeze(1)

    # Pad captions to the longest sequence in the batch: (batch_size, max_seq_len).
    batched_correct_capt = pad_sequence(
        correct_capt, batch_first=True, padding_value=0
    ).to(device)

    # Stack class labels: (batch_size, 1).
    batched_curr_class = torch.stack(curr_class, dim=0).to(device)

    # Pad per-word labels to the longest sequence: (batch_size, max_seq_len).
    batched_word_labels = pad_sequence(
        word_labels, batch_first=True, padding_value=0
    ).to(device)
    return (
        batched_img,
        batched_correct_capt,
        correct_capt_len,
        batched_curr_class,
        batched_word_labels,
    )
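

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# exercise custom_collate on a hand-built fake batch, the same shape of input
# a torch.utils.data.DataLoader would pass when configured with
#   collate_fn=functools.partial(custom_collate, device=device).
# The image size, token-id range, sequence lengths, and class ids below are
# arbitrary assumptions chosen for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    fake_batch = [
        (
            torch.rand(3, 224, 224),            # img
            torch.randint(1, 100, (seq_len,)),  # correct_capt (token ids)
            torch.tensor([cls]),                # curr_class
            torch.randint(0, 2, (seq_len,)),    # word_labels
        )
        for seq_len, cls in [(5, 0), (8, 1)]
    ]

    imgs, capts, capt_lens, classes, word_labels = custom_collate(fake_batch, device)

    # Expected shapes: (2, 3, 224, 224), (2, 8), (2, 1), (2, 1), (2, 8)
    print(imgs.shape, capts.shape, capt_lens.shape, classes.shape, word_labels.shape)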