|
from PIL import Image |
|
|
|
|
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
|
|
def _convert_to_rgb(image): |
|
return image.convert("RGB") |
|
|
|
|
|
# Preprocessing pipeline for visualization. Steps run in list order:
# resize, center-crop to 224x224, then force RGB mode. No tensor conversion
# or normalization step is included here.
visualization_preprocess = transforms.Compose(
    [
        # Integer ``size`` with bicubic resampling; per torchvision's Resize
        # docs an int size targets the shorter edge -- verify for the
        # installed torchvision version.
        transforms.Resize(size=224, interpolation=Image.BICUBIC),
        # Crop the central 224x224 region.
        transforms.CenterCrop(size=(224, 224)),
        # Plain callable step: Compose calls it with the image.
        _convert_to_rgb,
    ]
)
|
|
|
|
|
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` into a single ``rows`` x ``cols`` RGB image.

    Images are laid out in row-major order: image ``i`` is pasted at column
    ``i % cols`` and row ``i // cols``. Every cell is sized from
    ``imgs[0].size``; images are pasted as-is and are not resized, so the
    layout assumes all images share the first image's dimensions.

    Args:
        imgs: Indexable sequence of images accepted by ``Image.paste``.
            Must contain exactly ``rows * cols`` items, and at least one
            (``imgs[0]`` is read to determine the cell size).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new "RGB" image of size ``(cols * w, rows * h)`` where ``(w, h)``
        is ``imgs[0].size``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, "
        f"got {len(imgs)}"
    )

    cell_w, cell_h = imgs[0].size
    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, img in enumerate(imgs):
        # Row-major placement: divmod yields (row, column) for this index.
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid
|
|