# alignedthreeattn/utils/visualization.py
# File-viewer residue kept for provenance: huzey's picture | upload | 7acde1f |
# raw | history blame contribute delete | 649 Bytes
from PIL import Image
## Imports
from PIL import Image
from torchvision import transforms
def _convert_to_rgb(image):
    """Return ``image`` converted to the ``"RGB"`` mode.

    Args:
        image: Any object exposing a ``convert(mode)`` method (presumably a
            PIL image, since this module imports ``PIL.Image`` -- confirm
            against callers).

    Returns:
        Whatever ``image.convert("RGB")`` returns; the input object itself is
        not modified by this function.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
# Preprocessing pipeline for visualization. Steps run in the listed order:
#   1. Resize with size 224 using bicubic interpolation.
#   2. Center-crop to 224 x 224.
#   3. Convert the result to RGB via ``_convert_to_rgb``.
visualization_preprocess = transforms.Compose(
    [
        transforms.Resize(224, interpolation=Image.BICUBIC),
        transforms.CenterCrop((224, 224)),
        _convert_to_rgb,
    ]
)
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into one ``rows`` x ``cols`` RGB grid image.

    Images are laid out in row-major order: index ``i`` goes to row
    ``i // cols`` and column ``i % cols``. Every cell uses the size of the
    first image; the other images are pasted at cell origins without being
    resized.

    Args:
        imgs: Sequence of images; each must be accepted by ``Image.paste`` and
            the first must expose a ``size`` attribute of ``(width, height)``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new ``"RGB"`` image of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )
    # The cell size comes from the first image, so at least one is required.
    if not imgs:
        raise ValueError("image_grid requires at least one image")

    cell_w, cell_h = imgs[0].size
    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid