# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from tokenizer.tiktoken import IGNORE_INDEX


def padded_collate(
batch: List[Dict[str, List[int]]],
padding_idx: int = 0,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
"""Pad a batch of sequences to the longest sequence length in the batch, and
convert integer lists to tensors.
Args:
batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
padding_idx (int): Padding index for input ids. Defaults to 0.
ignore_idx (int): Padding index for labels. Defaults to -100.
Returns:
Dict[str, torch.Tensor]: Collated input and label tensors.
Example:
>>> token_pairs = [
>>> {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
>>> {"input_ids": [7,], "labels": [10,]},
>>> ]
>>> collated = padded_collate(
>>> batch=token_pairs,
>>> padding_idx=padding_idx,
>>> ignore_idx=ignore_idx,
>>> )
>>> collated["input_ids"]
>>> tensor([[1, 2, 3], [7, 0, 0]])
>>> collated["labels"]
>>> tensor([[4, 5, 6], [10, -100, -100]])
"""
    # ``pad_sequence`` expects a list of tensors, so convert the int lists first
    input_ids = pad_sequence(
        [torch.tensor(x["input_ids"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
input_ids_seq_len = input_ids.shape[-1]
labels_seq_len = labels.shape[-1]
    # Align lengths by padding the shorter of the two, rather than padding both
    # out to a fixed max_seq_len, which would be costly
if input_ids_seq_len > labels_seq_len:
labels = F.pad(
labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
)
elif labels_seq_len > input_ids_seq_len:
input_ids = F.pad(
input_ids,
(0, labels_seq_len - input_ids_seq_len),
value=padding_idx,
)
return {"input_ids": input_ids, "labels": labels}
# NOTE: Inspired by torchtune.data._collate
@dataclass
class MultiModalCollator:
padding_idx: int = 128004
ignore_idx: int = IGNORE_INDEX
pad_max_tiles: Optional[int] = None
pad_max_images: Optional[int] = None
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
and cross attention masks. This can be used for both training and inference.
``batch`` is expected to be a list of sample dicts containing the following::
- "input_ids": List[int] of length text_seq_len, varies across samples
- "labels": List[int] of length text_seq_len, varies across samples
- "encoder_input": Dict[str, List[torch.Tensor]]
- "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
- "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
Shape notation:
- c = channel dim
- h = height dim
- w = weight dim
Note:
For each element in the batch, ``len(images) == len(aspect_ratio)``.
This collater does the following:
(1) Pad text sequence and encoder mask to the longest sequence length in the batch
(2) Pad image tensors in the tile dimension with zeros to the largest number
of tiles in the batch
(3) Add empty images of zeros to samples up to max number of images in the batch
(4) Pad aspect ratios with (1,1) for all added padding images
        Args:
            batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
                labels, and encoder_input.
            padding_idx (int): Padding index for input token ids. Defaults to 128004.
            ignore_idx (int): Padding index for labels. Defaults to ``IGNORE_INDEX`` (-100).
            pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad
                to the largest number of tiles in the batch. Defaults to None.
            pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad
                to the largest number of images in the batch. Defaults to None.

        Returns:
            Dict[str, torch.Tensor]: Collated input_ids, labels, images, and aspect_ratio tensors.
                - input_ids: Tensor of shape (bsz, max_seq_len)
                - labels: Tensor of shape (bsz, max_seq_len)
                - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
                - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

        Example:
            >>> c, h, w = 1, 1, 1
>>> batch = [
... {
... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
... "encoder_input": {
... # One image with two tiles, one image with three tiles
... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
... },
... },
... {
... "input_ids": [1, 4], "labels": [8, 9],
... "encoder_input": {
... # One image with four tiles
... "images": [torch.ones(4, c, h, w)],
... "aspect_ratio": [torch.tensor([2, 2])],
... },
... },
... ]
            >>> collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)
>>> model_inputs = collator(batch=batch)
>>> print(model_inputs["input_ids"])
tensor([[1, 2, 1, 3],
[1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
[8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
"""
        # Text tokens can be handled independently by the existing text collator
text_only = [
{"input_ids": sample["input_ids"], "labels": sample["labels"]}
for sample in batch
]
collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)
        if self.pad_max_tiles is None:
            # Get max number of tiles across all images in the batch
            max_num_tiles = max(
                image.shape[0]
                for sample in batch
                for image in sample["encoder_input"]["images"]
            )
        else:
            max_num_tiles = self.pad_max_tiles
# Pad images and aspect ratios to max number of tiles
batch_images = []
batch_aspect_ratios = []
for sample in batch:
sample_images = []
for image in sample["encoder_input"]["images"]:
# Single image in each sample has shape (n_tiles, c, h, w)
n_tiles = image.shape[0]
                # ``F.pad`` takes padding for the last dims first, so the trailing
                # pair pads the tile dim: (w, w, h, h, c, c, tiles_front, tiles_back)
padding_tiles = max_num_tiles - n_tiles
# Image should now have shape (max_num_tiles, c, h, w)
padded_image = F.pad(
image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
)
sample_images.append(padded_image)
            # Stack multiple images per sample in the num_images dimension
batch_images.append(torch.stack(sample_images))
batch_aspect_ratios.append(
torch.stack(sample["encoder_input"]["aspect_ratio"])
)
        # Finally, pad images and aspect ratios to max number of images in batch
# (bsz, max_num_images, max_num_tiles, c, h, w)
collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
# (bsz, max_num_images, 2)
collated_aspect_ratios = pad_sequence(
batch_aspect_ratios, batch_first=True, padding_value=1
)
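        # Optionally pad the num_images dim out to ``pad_max_images``; a minimal
        # sketch, mirroring the tile padding above. ``F.pad`` pairs run from the
        # last dim backwards, so only the final pair, which pads num_images, is nonzero.
        if self.pad_max_images is not None:
            extra_images = self.pad_max_images - collated_images.shape[1]
            if extra_images > 0:
                collated_images = F.pad(
                    collated_images, (0, 0, 0, 0, 0, 0, 0, 0, 0, extra_images), value=0
                )
                # Keep aspect ratios aligned, padding with (1, 1) as for other padding images
                collated_aspect_ratios = F.pad(
                    collated_aspect_ratios, (0, 0, 0, extra_images), value=1
                )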
batch_dict = {
"input_ids": collated_text["input_ids"],
"labels": collated_text["labels"],
"encoder_input": {
"images": collated_images,
"aspect_ratio": collated_aspect_ratios,
},
}
return batch_dict
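

# Example usage (sketch): ``MultiModalCollator`` is a callable dataclass, so an
# instance can be handed to ``DataLoader`` directly. ``my_vlm_dataset`` is a
# hypothetical dataset yielding dicts shaped like the ``__call__`` docstring
# example (input_ids, labels, and an encoder_input dict of image tensors).
#
#     from torch.utils.data import DataLoader
#
#     collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)
#     loader = DataLoader(my_vlm_dataset, batch_size=8, collate_fn=collator)
#     model_inputs = next(iter(loader))
#     # model_inputs["encoder_input"]["images"].shape
#     # -> (8, max_num_images, 4, c, h, w)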