|
import torch |
|
import numpy as np |
|
from torchvision import transforms |
|
from PIL import Image, ImageFilter |
|
import random |
|
|
|
|
|
def set_random_seed(seed=0):
    """Seed the torch, CUDA, NumPy and stdlib random generators.

    Each generator receives ``seed`` plus a distinct offset (0..5), applied
    in the order listed below.

    Args:
        seed: Base integer seed.
    """
    # Order matters: position in this tuple is the offset added to `seed`.
    # `torch.cuda.manual_seed_all` appears twice; the later call (offset 4)
    # is the one that is applied last.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for offset, apply_seed in enumerate(seeders):
        apply_seed(seed + offset)
|
|
|
|
|
def transform_img(image, target_size=512):
    """Resize, center-crop and tensorize an image, then rescale it.

    The image goes through ``Resize(target_size)``,
    ``CenterCrop(target_size)`` and ``ToTensor()`` in that order, and the
    result ``t`` is returned as ``2 * t - 1``.

    Args:
        image: Input accepted by the torchvision transforms
            (presumably a PIL image -- confirm against callers).
        target_size: Size passed to both the resize and the center crop.

    Returns:
        The transformed tensor; in [-1, 1] assuming ``ToTensor`` yields
        values in [0, 1].
    """
    resized = transforms.Resize(target_size)(image)
    cropped = transforms.CenterCrop(target_size)(resized)
    tensor = transforms.ToTensor()(cropped)
    return 2.0 * tensor - 1.0
|
|
|
|
|
def latents_to_imgs(pipe, latents):
    """Turn latents into images using the pipeline's own helpers.

    Chains ``pipe.decode_image`` -> ``pipe.torch_to_numpy`` ->
    ``pipe.numpy_to_pil`` and returns the final result.

    Args:
        pipe: Object exposing the three helper methods above.
        latents: Value handed to ``pipe.decode_image``.

    Returns:
        Whatever ``pipe.numpy_to_pil`` returns (presumably PIL images --
        confirm against the pipeline implementation).
    """
    return pipe.numpy_to_pil(pipe.torch_to_numpy(pipe.decode_image(latents)))
|
|
|
def image_distortion(img,
                     seed: int = 42,
                     random_crop_ratio: float = None,
                     random_drop_ratio: float = None,
                     resize_ratio: float = None,
                     gaussian_blur_r: int = None,
                     gaussian_std: float = None,
                     sp_prob: float = None):
    """Apply the requested distortions to a PIL image and return the result.

    Every distortion is optional (``None`` disables it). Enabled distortions
    are applied in this fixed order: random crop, random drop, resize
    round-trip, gaussian blur, gaussian noise, salt-and-pepper noise. With
    every option left at ``None`` the input is returned untouched.

    Args:
        img: PIL image to distort (pixel values are treated as 0..255).
        seed: Seed passed to ``set_random_seed`` before the crop and drop
            window is drawn. NOTE: this reseeds the global torch / NumPy /
            ``random`` generators. The gaussian-noise and salt-and-pepper
            steps do not reseed; they draw from the current NumPy state.
        random_crop_ratio: Fraction of each side to keep. A random window of
            that size is preserved and everything outside it is zeroed; the
            image size does not change.
        random_drop_ratio: Fraction of each side to erase. A random window of
            that size is zeroed.
        resize_ratio: Scale both sides by this factor, then scale back to
            the original size.
        gaussian_blur_r: Radius for ``ImageFilter.GaussianBlur``.
        gaussian_std: Standard deviation of additive zero-mean gaussian
            noise, expressed as a fraction of the 0..255 range.
        sp_prob: Total probability that an array element is replaced; half
            of the affected elements become 0 and half become 255. Elements
            are sampled independently (per channel, not per pixel).

    Returns:
        The distorted PIL image.

    Raises:
        ValueError: From ``np.random.randint`` when a crop/drop ratio is
            greater than 1 (the window would not fit in the image).
    """
    if random_crop_ratio is not None:
        set_random_seed(seed)
        pixels = np.array(img)
        # NumPy image arrays are (height, width[, channels]).
        top, left, win_h, win_w = _random_window(
            pixels.shape[0], pixels.shape[1], random_crop_ratio
        )
        canvas = np.zeros_like(pixels)
        canvas[top:top + win_h, left:left + win_w] = \
            pixels[top:top + win_h, left:left + win_w]
        img = Image.fromarray(canvas)

    if random_drop_ratio is not None:
        set_random_seed(seed)
        pixels = np.array(img)
        top, left, win_h, win_w = _random_window(
            pixels.shape[0], pixels.shape[1], random_drop_ratio
        )
        pixels[top:top + win_h, left:left + win_w] = 0
        img = Image.fromarray(pixels)

    if resize_ratio is not None:
        height, width = np.array(img).shape[:2]
        # Explicit (h, w) targets: an int size only fixes the shorter edge,
        # which would not restore the original size of a non-square image.
        # Clamp to 1 so a tiny ratio cannot request an empty image.
        small = (
            max(1, int(height * resize_ratio)),
            max(1, int(width * resize_ratio)),
        )
        img = transforms.Resize(size=small)(img)
        img = transforms.Resize(size=(height, width))(img)

    if gaussian_blur_r is not None:
        img = img.filter(ImageFilter.GaussianBlur(radius=gaussian_blur_r))

    if gaussian_std is not None:
        pixels = np.array(img)
        noise = np.random.normal(0, gaussian_std, pixels.shape) * 255
        # Add in floating point and clip BEFORE converting back: doing the
        # sum in uint8 wraps around modulo 256 and makes the clip a no-op.
        noisy = np.clip(pixels.astype(np.float64) + noise, 0, 255)
        img = Image.fromarray(np.rint(noisy).astype(pixels.dtype))

    if sp_prob is not None:
        pixels = np.array(img)
        half = sp_prob / 2
        rolls = np.random.rand(*pixels.shape)
        # Upper tail -> 0, lower tail -> 255, each with probability half.
        pixels[rolls > 1 - half] = 0
        pixels[rolls < half] = 255
        img = Image.fromarray(pixels)

    return img


def _random_window(height, width, ratio):
    """Draw a random window spanning ``ratio`` of each image side.

    Returns ``(top, left, window_height, window_width)``. The horizontal
    offset is drawn before the vertical one.
    """
    win_w = int(width * ratio)
    win_h = int(height * ratio)
    left = np.random.randint(0, width - win_w + 1)
    top = np.random.randint(0, height - win_h + 1)
    return top, left, win_h, win_w
|
|
|
def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
    """Score how well each image matches a text prompt.

    Each image is preprocessed and encoded with ``model.encode_image``; the
    prompt is tokenized and encoded with ``model.encode_text``. Both sets of
    features are divided by their norm along the last dimension, so the
    matrix product below is a cosine similarity. Runs under
    ``torch.no_grad()``.

    Args:
        images: Iterable of images accepted by ``clip_preprocess``.
        prompt: Single text prompt.
        model: Object providing ``encode_image`` and ``encode_text``.
        clip_preprocess: Callable turning one image into a tensor.
        tokenizer: Callable turning a list of prompts into a tensor.
        device: Device the inputs are moved to before encoding.

    Returns:
        Tensor with one similarity value per image (mean over the text
        axis of ``image_features @ text_features.T``).
    """
    with torch.no_grad():
        # Stacking adds the leading batch axis for every preprocessed image.
        batch = torch.stack([clip_preprocess(picture) for picture in images])
        image_features = model.encode_image(batch.to(device))

        tokens = tokenizer([prompt]).to(device)
        text_features = model.encode_text(tokens)

        # Normalize in place so the dot product is a cosine similarity.
        image_features.div_(image_features.norm(dim=-1, keepdim=True))
        text_features.div_(text_features.norm(dim=-1, keepdim=True))

        similarity = image_features @ text_features.T
        return similarity.mean(-1)
|
|