from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from tokenizer.tiktoken import IGNORE_INDEX

|
def padded_collate(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of sequences to the longest sequence length in the batch, and
    convert integer lists to tensors.

    Args:
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.

    Returns:
        Dict[str, torch.Tensor]: Collated input and label tensors.

    Example:
        >>> token_pairs = [
        ...     {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        ...     {"input_ids": [7], "labels": [10]},
        ... ]
        >>> collated = padded_collate(
        ...     batch=token_pairs,
        ...     padding_idx=0,
        ...     ignore_idx=-100,
        ... )
        >>> collated["input_ids"]
        tensor([[1, 2, 3], [7, 0, 0]])
        >>> collated["labels"]
        tensor([[4, 5, 6], [10, -100, -100]])
    """
    # pad_sequence expects a list of tensors, so convert each integer list first.
    input_ids = pad_sequence(
        [torch.tensor(x["input_ids"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]

    # Inputs and labels can pad to different lengths when their list lengths
    # differ across the batch; right-pad the shorter tensor so both match.
    if input_ids_seq_len > labels_seq_len:
        labels = F.pad(
            labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
        )
    elif labels_seq_len > input_ids_seq_len:
        input_ids = F.pad(
            input_ids,
            (0, labels_seq_len - input_ids_seq_len),
            value=padding_idx,
        )
    return {"input_ids": input_ids, "labels": labels}
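# Usage sketch: wiring padded_collate into a DataLoader via functools.partial.
# ``MyTokenDataset`` is a hypothetical dataset whose items look like
# {"input_ids": [...], "labels": [...]}:
#
#     from functools import partial
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(
#         MyTokenDataset(),
#         batch_size=8,
#         collate_fn=partial(padded_collate, padding_idx=0, ignore_idx=-100),
#     )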
|
|
@dataclass
class MultiModalCollator:
    padding_idx: int = 128004
    ignore_idx: int = IGNORE_INDEX
    pad_max_tiles: Optional[int] = None
    pad_max_images: Optional[int] = None

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Pad a batch of text sequences, tiled image tensors, and aspect ratios.
        This can be used for both training and inference.

        ``batch`` is expected to be a list of sample dicts containing the following::
            - "input_ids": List[int] of length text_seq_len, varies across samples
            - "labels": List[int] of length text_seq_len, varies across samples
            - "encoder_input": Dict[str, List[torch.Tensor]]
                - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
                - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio

        Shape notation:
            - c = channel dim
            - h = height dim
            - w = width dim

        Note:
            For each element in the batch, ``len(images) == len(aspect_ratio)``.

        This collator does the following:
            (1) Pad text sequences to the longest sequence length in the batch
            (2) Pad image tensors in the tile dimension with zeros to the largest
                number of tiles in the batch (or to ``pad_max_tiles``, if set)
            (3) Add empty images of zeros to samples up to the max number of images
                in the batch (or to ``pad_max_images``, if set)
            (4) Pad aspect ratios with (1,1) for all added padding images

        Args:
            batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
                labels, and encoder_input.

        Padding behavior is controlled by the collator's attributes:
            - padding_idx (int): Padding index for input token ids. Defaults to 128004.
            - ignore_idx (int): Padding index for labels. Defaults to ``IGNORE_INDEX``.
            - pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None,
              will pad to the largest number of tiles in the batch. Defaults to None.
            - pad_max_images (Optional[int]): Maximum number of images to pad to. If None,
              will pad to the largest number of images in the batch. Defaults to None.

        Returns:
            Dict[str, torch.Tensor]: Collated input_ids, labels, images, and aspect_ratio tensors.
                - input_ids: Tensor of shape (bsz, max_seq_len)
                - labels: Tensor of shape (bsz, max_seq_len)
                - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
                - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

        Example:
            >>> c, h, w = 1, 1, 1
            >>> batch = [
            ...     {
            ...         "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
            ...         "encoder_input": {
            ...             # One image with two tiles, one image with three tiles
            ...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
            ...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
            ...         },
            ...     },
            ...     {
            ...         "input_ids": [1, 4], "labels": [8, 9],
            ...         "encoder_input": {
            ...             # One image with four tiles
            ...             "images": [torch.ones(4, c, h, w)],
            ...             "aspect_ratio": [torch.tensor([2, 2])],
            ...         },
            ...     },
            ... ]
            >>> collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)
            >>> model_inputs = collator(batch=batch)
            >>> print(model_inputs["input_ids"])
            tensor([[1, 2, 1, 3],
                    [1, 4, 0, 0]])
            >>> print(model_inputs["labels"])
            tensor([[4, 5, 6, 7],
                    [8, 9, -100, -100]])
            >>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
            torch.Size([2, 2, 4, 1, 1, 1])
            >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
            torch.Size([2, 2, 2])
            >>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
            tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
            >>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
            tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
            >>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
            tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
            >>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
            tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
        """
        text_only = [
            {"input_ids": sample["input_ids"], "labels": sample["labels"]}
            for sample in batch
        ]
        collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)

        # Find the largest tile count across all images in the batch; the image
        # tensors live under encoder_input["images"] with shape (n_tiles, c, h, w).
        largest_num_tiles = max(
            image.shape[0]
            for sample in batch
            for image in sample["encoder_input"]["images"]
        )
        if self.pad_max_tiles is None:
            max_num_tiles = largest_num_tiles
        else:
            # Guard against silent truncation: a negative pad width in F.pad
            # would drop tiles rather than raise.
            if largest_num_tiles > self.pad_max_tiles:
                raise ValueError(
                    f"More tiles in an image ({largest_num_tiles}) than "
                    f"pad_max_tiles ({self.pad_max_tiles})"
                )
            max_num_tiles = self.pad_max_tiles
|
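        # Per-image tile padding: every image tensor is (n_tiles, c, h, w), and
        # F.pad's pad tuple runs from the last dim backwards, two entries per
        # dim: (w_left, w_right, h_left, h_right, c_left, c_right, tiles_left,
        # tiles_right). Only tiles_right is nonzero below, so each image is
        # zero-padded on the tile dim up to max_num_tiles.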
        batch_images = []
        batch_aspect_ratios = []

        for sample in batch:
            sample_images = []
            for image in sample["encoder_input"]["images"]:
                n_tiles = image.shape[0]
                padding_tiles = max_num_tiles - n_tiles
                padded_image = F.pad(
                    image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
                )
                sample_images.append(padded_image)
            batch_images.append(torch.stack(sample_images))
            batch_aspect_ratios.append(
                torch.stack(sample["encoder_input"]["aspect_ratio"])
            )
|
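        # Image-count padding: batch_images holds one stacked tensor per sample
        # of shape (n_images, max_num_tiles, c, h, w); pad_sequence right-pads
        # the image dim with all-zero images up to the batch max. Padding the
        # aspect ratios with 1 gives those placeholder images a (1, 1) ratio.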
        collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
        collated_aspect_ratios = pad_sequence(
            batch_aspect_ratios, batch_first=True, padding_value=1
        )
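
        # A minimal sketch of the documented ``pad_max_images`` behavior:
        # right-pad the image dim out to a fixed count, assuming zero-filled
        # placeholder images and (1, 1) aspect ratios are acceptable downstream.
        if self.pad_max_images is not None:
            max_num_images = collated_images.shape[1]
            extra_images = self.pad_max_images - max_num_images
            if extra_images < 0:
                raise ValueError(
                    f"More images in a sample ({max_num_images}) than "
                    f"pad_max_images ({self.pad_max_images})"
                )
            # images: (bsz, n_images, n_tiles, c, h, w) -> pad dim 1 on the right
            collated_images = F.pad(
                collated_images, (0, 0, 0, 0, 0, 0, 0, 0, 0, extra_images), value=0
            )
            # aspect_ratio: (bsz, n_images, 2) -> pad dim 1 on the right with 1s
            collated_aspect_ratios = F.pad(
                collated_aspect_ratios, (0, 0, 0, extra_images), value=1
            )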
|
        batch_dict = {
            "input_ids": collated_text["input_ids"],
            "labels": collated_text["labels"],
            "encoder_input": {
                "images": collated_images,
                "aspect_ratio": collated_aspect_ratios,
            },
        }
        return batch_dict
|
|
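if __name__ == "__main__":
    # Minimal smoke test, mirroring the docstring example above. This is a
    # sketch: it assumes the module's imports (including tokenizer.tiktoken)
    # resolve in your environment.
    c, h, w = 1, 1, 1
    batch = [
        {
            "input_ids": [1, 2, 1, 3],
            "labels": [4, 5, 6, 7],
            "encoder_input": {
                "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
                "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
            },
        },
        {
            "input_ids": [1, 4],
            "labels": [8, 9],
            "encoder_input": {
                "images": [torch.ones(4, c, h, w)],
                "aspect_ratio": [torch.tensor([2, 2])],
            },
        },
    ]
    collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)
    out = collator(batch=batch)
    print(out["input_ids"])                            # (bsz=2, max_seq_len=4)
    print(out["labels"])
    print(out["encoder_input"]["images"].shape)        # torch.Size([2, 2, 4, 1, 1, 1])
    print(out["encoder_input"]["aspect_ratio"].shape)  # torch.Size([2, 2, 2])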