MTVCrafter / utils.py
yanboding's picture
Upload 32 files
30a0a93 verified
raw
history blame
3.09 kB
import torch
import random
import numpy as np
from PIL import Image
def concat_images(images, direction='horizontal', pad=0, pad_value=0):
    """Join several images into one, side by side or stacked.

    Images of different sizes are aligned to the top-left corner; any area
    not covered by an image (including the ``pad`` gaps) is filled with
    ``pad_value``.

    Args:
        images: Non-empty sequence of images. Either ``PIL.Image.Image``
            objects or numpy arrays shaped ``(H, W)`` or ``(H, W, C)``.
            Only the first element is inspected to decide whether the
            inputs are PIL images.
        direction: ``'horizontal'`` (left to right) or ``'vertical'``
            (top to bottom).
        pad: Gap, in pixels, inserted between neighbouring images.
        pad_value: Fill value for gaps and uncovered regions.

    Returns:
        The combined image: a PIL image if the first input was one,
        otherwise a numpy array with the dtype of the first input. A
        single-element input is returned unchanged.

    Raises:
        ValueError: If ``images`` is empty or ``direction`` is unknown.
    """
    if len(images) == 0:
        raise ValueError('concat_images() requires at least one image')
    if len(images) == 1:
        return images[0]

    # A numpy array can never be a PIL image, so PIL is only consulted
    # for non-array inputs.
    first = images[0]
    is_pil = not isinstance(first, np.ndarray) and isinstance(first, Image.Image)
    if is_pil:
        images = [np.array(image) for image in images]

    if direction == 'horizontal':
        axis = 1
    elif direction == 'vertical':
        axis = 0
    else:
        raise ValueError(
            "direction must be 'horizontal' or 'vertical', got {!r}".format(direction)
        )
    cross_axis = 1 - axis

    # Images are laid out along `axis`; the canvas is as large as the
    # largest image along the other axis.
    canvas_hw = [0, 0]
    canvas_hw[axis] = sum(image.shape[axis] for image in images) + pad * (len(images) - 1)
    canvas_hw[cross_axis] = max(image.shape[cross_axis] for image in images)

    # shape[2:] is () for 2-D (grayscale) arrays and (C,) for colour ones,
    # so both layouts are supported.
    canvas_shape = tuple(canvas_hw) + tuple(images[0].shape[2:])
    new_image = np.full(canvas_shape, pad_value, dtype=images[0].dtype)

    begin = 0
    for image in images:
        end = begin + image.shape[axis]
        if axis == 1:
            new_image[: image.shape[0], begin:end] = image
        else:
            new_image[begin:end, : image.shape[1]] = image
        begin = end + pad

    if is_pil:
        new_image = Image.fromarray(new_image)
    return new_image
def concat_images_grid(images, cols, pad=0, pad_value=0):
    """Arrange images into a grid with ``cols`` images per row.

    Each run of ``cols`` consecutive images is joined horizontally with
    ``concat_images``; the resulting rows are then joined vertically. The
    last row may hold fewer than ``cols`` images.

    Args:
        images: Non-empty sequence of images accepted by ``concat_images``.
        cols: Number of images per row; must be at least 1.
        pad: Gap, in pixels, between neighbouring images and between rows.
        pad_value: Fill value for gaps and uncovered regions.

    Returns:
        The grid image, as returned by ``concat_images``.

    Raises:
        ValueError: If ``cols`` is less than 1 or ``images`` is empty.
    """
    # Reject bad arguments up front; otherwise they only surface as an
    # unrelated IndexError from the row-building call.
    if cols < 1:
        raise ValueError('cols must be a positive integer, got {!r}'.format(cols))
    if len(images) == 0:
        raise ValueError('concat_images_grid() requires at least one image')

    # Step through row start offsets rather than repeatedly re-slicing
    # the remaining tail of the sequence.
    rows = [
        concat_images(images[start:start + cols], pad=pad, pad_value=pad_value)
        for start in range(0, len(images), cols)
    ]
    return concat_images(rows, direction='vertical', pad=pad, pad_value=pad_value)
def sample_video(video, indexes, method=2):
    """Fetch the frames at ``indexes`` from ``video`` as a numpy array.

    Args:
        video: Object exposing ``get_batch(indexes)`` whose result is either
            a ``torch.Tensor`` or an object with an ``asnumpy()`` method
            (presumably a decord ``VideoReader`` -- confirm against callers).
        indexes: Frame indexes to fetch. For ``method=2`` any array-like of
            integers is accepted; for ``method=1`` it is passed to
            ``video.get_batch`` unchanged.
        method: ``1`` requests exactly ``indexes`` from the reader. ``2``
            (default) requests every frame from 0 up to ``max(indexes)`` and
            then selects ``indexes`` from the result.
            NOTE(review): method 2 looks like a workaround for unreliable
            random access in the reader -- confirm before changing the
            default.

    Returns:
        numpy array holding the requested frames, in the order of
        ``indexes``.

    Raises:
        ValueError: If ``method`` is neither 1 nor 2.
    """
    if method == 1:
        frames = video.get_batch(indexes)
    elif method == 2:
        # Accept lists/tuples as well as arrays.
        indexes = np.asarray(indexes)
        # Request the whole prefix; the wanted frames are picked out below.
        frames = video.get_batch(np.arange(indexes.max() + 1, dtype=int))
    else:
        raise ValueError('method must be 1 or 2, got {!r}'.format(method))

    # The reader may hand back a torch tensor or a type exposing asnumpy().
    frames = frames.numpy() if isinstance(frames, torch.Tensor) else frames.asnumpy()

    if method == 2:
        frames = frames[indexes]
    return frames
def get_sample_indexes(video_length, num_frames, stride):
    """Pick ``num_frames`` evenly spaced frame indexes from a random window.

    The window spans ``(num_frames - 1) * stride + 1`` frames (capped at
    ``video_length``) and its starting frame is drawn uniformly with
    ``random.randint`` from the positions where the window still fits.

    Args:
        video_length: Total number of frames available.
        num_frames: Number of indexes to return.
        stride: Spacing between consecutive sampled frames.

    Returns:
        Integer numpy array of length ``num_frames``.

    Raises:
        AssertionError: If ``num_frames * stride`` exceeds ``video_length``.
    """
    assert num_frames * stride <= video_length

    # Frames covered from the first sampled index to the last, inclusive.
    window = (num_frames - 1) * stride + 1
    if video_length <= window:
        window = video_length

    first = random.randint(0, video_length - window)
    last = first + window - 1
    return np.linspace(first, last, num_frames, dtype=int)
def get_new_height_width(data_dict, dst_height, dst_width):
    """Compute an aspect-preserving size that covers the target box.

    The source size is scaled by the larger of the two per-axis ratios, so
    one output dimension equals its target and the other is at least as
    large as its target.

    Args:
        data_dict: Mapping with the source size under the keys
            ``'video_height'`` and ``'video_width'``.
        dst_height: Target height.
        dst_width: Target width.

    Returns:
        Tuple ``(new_height, new_width)``.
    """
    src_h = data_dict['video_height']
    src_w = data_dict['video_width']

    ratio_h = float(dst_height) / src_h
    ratio_w = float(dst_width) / src_w

    if ratio_w > ratio_h:
        # Width needs the larger scale: pin the width, derive the height.
        resized = (int(round(ratio_w * src_h)), dst_width)
    else:
        # Height needs the larger (or an equal) scale: pin the height.
        resized = (dst_height, int(round(ratio_h * src_w)))

    new_height, new_width = resized
    # Sanity check: the result must fully cover the requested box.
    assert dst_width <= new_width and dst_height <= new_height
    return new_height, new_width