Commit a23872f · initial commit
Erwann Millon committed
0 Parent(s)
Browse files
- .gitattributes +1 -0
- .gitignore +9 -0
- .gitmodules +3 -0
- ImageState.py +192 -0
- animation.py +34 -0
- app.py +195 -0
- app_backend.py +243 -0
- configs.py +7 -0
- edit.py +69 -0
- img_processing.py +72 -0
- loaders.py +97 -0
- masking.py +32 -0
- prompts.py +17 -0
- requirements.txt +27 -0
- unwrapped.yaml +37 -0
- utils.py +18 -0
- vqgan_latent_ops.py +14 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+vqgan_only.pt filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,9 @@
+__pycache__
+__pycache__/*
+img_history
+aligned
+wandb
+women
+men
+gradio_cache*
+.DS_Store
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "taming"]
+    path = taming
+    url = https://github.com/CompVis/taming-transformers.git
ImageState.py
ADDED
@@ -0,0 +1,192 @@
+# from align import align_from_path
+from animation import clear_img_dir
+from app_backend import ImagePromptOptimizer, log
+from functools import cache
+import importlib
+
+import gradio as gr
+import matplotlib.pyplot as plt
+import torch
+import torchvision
+import wandb
+from icecream import ic
+from torch import nn
+from torchvision.transforms.functional import resize
+from tqdm import tqdm
+from transformers import CLIPModel, CLIPProcessor
+import lpips
+from app_backend import get_resized_tensor
+from edit import blend_paths
+from img_processing import *
+from img_processing import custom_to_pil
+from loaders import load_default
+
+num = 0
+class PromptTransformHistory():
+    def __init__(self, iterations) -> None:
+        self.iterations = iterations
+        self.transforms = []
+
+class ImageState:
+    def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
+        self.vqgan = vqgan
+        self.device = vqgan.device
+        self.blend_latent = None
+        self.quant = True
+        self.path1 = None
+        self.path2 = None
+        self.transform_history = []
+        self.attn_mask = None
+        self.prompt_optim = prompt_optimizer
+        self._load_vectors()
+        self.init_transforms()
+    def _load_vectors(self):
+        self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
+        self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
+        self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
+        self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
+    def init_transforms(self):
+        self.blue_eyes = torch.zeros_like(self.lip_vector)
+        self.lip_size = torch.zeros_like(self.lip_vector)
+        self.asian_transform = torch.zeros_like(self.lip_vector)
+        self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
+        self.hair_gp = torch.zeros_like(self.lip_vector)
+    def clear_transforms(self):
+        global num
+        self.init_transforms()
+        clear_img_dir()
+        num = 0
+        return self._render_all_transformations()
+    def _apply_vector(self, src, vector):
+        new_latent = torch.lerp(src, src + vector, 1)
+        return new_latent
+    def _decode_latent_to_pil(self, latent):
+        current_im = self.vqgan.decode(latent.to(self.device))[0]
+        return custom_to_pil(current_im)
+    # def _get_current_vector_transforms(self):
+    #     current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
+    #     return (self.blend_latent, current_vector_transforms)
+    # @cache
+    def get_mask(self, img, mask=None):
+        if img and "mask" in img and img["mask"] is not None:
+            attn_mask = torchvision.transforms.ToTensor()(img["mask"])
+            attn_mask = torch.ceil(attn_mask[0].to(self.device))
+            plt.imshow(attn_mask.detach().cpu(), cmap="Blues")
+            plt.show()
+            torch.save(attn_mask, "test_mask.pt")
+            print("mask set successfully")
+            # attn_mask = self.rescale_mask(attn_mask)
+            print(type(attn_mask))
+            print(attn_mask.shape)
+        else:
+            attn_mask = mask
+            print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
+        return attn_mask
+    def set_mask(self, img):
+        attn_mask = self.get_mask(img)
+        self.attn_mask = attn_mask
+        # attn_mask = torch.ones_like(img, device=self.device)
+        x = attn_mask.clone()
+        x = x.detach().cpu()
+        x = torch.clamp(x, -1., 1.)
+        x = (x + 1.)/2.
+        x = x.numpy()
+        x = (255*x).astype(np.uint8)
+        x = Image.fromarray(x, "L")
+        return x
+    @torch.no_grad()
+    def _render_all_transformations(self, return_twice=True):
+        global num
+        current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
+        new_latent = self.blend_latent + sum(current_vector_transforms)
+        if self.quant:
+            new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
+        image = self._decode_latent_to_pil(new_latent)
+        img_dir = "./img_history"
+        if not os.path.exists(img_dir):
+            os.mkdir(img_dir)
+        image.save(f"./img_history/img_{num:06}.png")
+        num += 1
+        return (image, image) if return_twice else image
+    def apply_gp_vector(self, weight):
+        self.hair_gp = weight * self.green_purple_vector
+        return self._render_all_transformations()
+    def apply_rb_vector(self, weight):
+        self.blue_eyes = weight * self.red_blue_vector
+        return self._render_all_transformations()
+    def apply_lip_vector(self, weight):
+        self.lip_size = weight * self.lip_vector
+        return self._render_all_transformations()
+    def update_requant(self, val):
+        print(f"val = {val}")
+        self.quant = val
+        return self._render_all_transformations()
+    def apply_gender_vector(self, weight):
+        self.asian_transform = weight * self.asian_vector
+        return self._render_all_transformations()
+    def update_images(self, path1, path2, blend_weight):
+        if path1 is None and path2 is None:
+            return None
+        if path1 is None: path1 = path2
+        if path2 is None: path2 = path1
+        self.path1, self.path2 = path1, path2
+        # self.aligned_path1 = align_from_path(path1)
+        # self.aligned_path2 = align_from_path(path2)
+        return self.blend(blend_weight)
+    @torch.no_grad()
+    def blend(self, weight):
+        _, latent = blend_paths(self.vqgan, self.path1, self.path2, weight=weight, show=False, device=self.device)
+        self.blend_latent = latent
+        return self._render_all_transformations()
+    @torch.no_grad()
+    def rewind(self, index):
+        if not self.transform_history:
+            print("no history")
+            return self._render_all_transformations()
+        prompt_transform = self.transform_history[-1]
+        latent_index = int(index / 100 * (prompt_transform.iterations - 1))
+        print(latent_index)
+        self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
+        # print(self.current_prompt_transform)
+        # print(self.current_prompt_transforms.mean())
+        return self._render_all_transformations()
+    def rescale_mask(self, mask):
+        rep = mask.clone()
+        rep[mask < 0.03] = -1000000
+        rep[mask >= 0.03] = 1
+        return rep
+    def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
+        transform_log = PromptTransformHistory(iterations + reconstruction_steps)
+        transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
+        self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
+        if log:
+            wandb.init(reinit=True, project="face-editor")
+            wandb.config.update({"Positive Prompts": positive_prompts})
+            wandb.config.update({"Negative Prompts": negative_prompts})
+            wandb.config.update(dict(
+                lr=lr,
+                iterations=iterations,
+                lpips_weight=lpips_weight
+            ))
+        positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
+        negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
+        self.prompt_optim.set_params(lr, iterations, lpips_weight, attn_mask=self.attn_mask, reconstruction_steps=reconstruction_steps)
+        for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
+                                                                 positive_prompts,
+                                                                 negative_prompts)):
+            transform_log.transforms.append(transform.clone().detach())
+            self.current_prompt_transforms[-1] = transform
+            with torch.no_grad():
+                image = self._render_all_transformations(return_twice=False)
+            if log:
+                wandb.log({"image": wandb.Image(image)})
+            yield (image, image)
+        if log:
+            wandb.finish()
+        self.attn_mask = None
+        self.transform_history.append(transform_log)
+        # transform = self.prompt_optim.optimize(self.blend_latent,
+        #                                        positive_prompts,
+        #                                        negative_prompts)
+        # self.prompt_transforms = transform
+        # return self._render_all_transformations()
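Taken together, the editing logic in ImageState is latent arithmetic: each slider or prompt edit is stored as a direction tensor, the weighted directions are summed onto the blended VQGAN latent, the result is optionally re-quantized against the codebook, and the decoder renders the image. A minimal sketch of that idea, assuming a vqgan object like the one returned by loaders.load_default and pre-computed direction tensors (this helper is illustrative only and not part of the commit):

    import torch
    from img_processing import custom_to_pil

    def render_edit(vqgan, blend_latent, directions, weights, requantize=True):
        # weighted sum of edit directions, added to the base (blended) latent
        latent = blend_latent + sum(w * d for w, d in zip(weights, directions))
        if requantize:
            latent, _, _ = vqgan.quantize(latent)  # snap back onto the VQGAN codebook
        return custom_to_pil(vqgan.decode(latent)[0])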
animation.py
ADDED
@@ -0,0 +1,34 @@
+import imageio
+import glob
+import os
+
+def clear_img_dir():
+    img_dir = "./img_history"
+    if not os.path.exists(img_dir):
+        os.mkdir(img_dir)
+    for filename in glob.glob(img_dir+"/*"):
+        os.remove(filename)
+
+
+def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
+    images = []
+    paths = glob.glob(folder + "/*")
+    frame_duration = total_duration / len(paths)
+    print(len(paths), "frame dur", frame_duration)
+    durations = [frame_duration] * len(paths)
+    if extend_frames:
+        durations[0] = 1.5
+        durations[-1] = 3
+    for file_name in os.listdir(folder):
+        if file_name.endswith('.png'):
+            file_path = os.path.join(folder, file_name)
+            images.append(imageio.imread(file_path))
+    # images[0] = images[0].set_meta_data({'duration': 1})
+    # images[-1] = images[-1].set_meta_data({'duration': 1})
+    imageio.mimsave(gif_name, images, duration=durations)
+    return gif_name
+
+if __name__ == "__main__":
+    # clear_img_dir()
+    create_gif()
+    # make_animation()
app.py
ADDED
@@ -0,0 +1,195 @@
+import glob
+import os
+import sys
+
+import wandb
+
+from configs import set_major_global, set_major_local, set_small_local
+
+sys.path.append("taming-transformers")
+import functools
+
+import gradio as gr
+from transformers import CLIPModel, CLIPProcessor
+
+import edit
+# import importlib
+# importlib.reload(edit)
+from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
+from ImageState import ImageState
+from loaders import load_default
+from animation import create_gif
+from prompts import get_random_prompts
+
+device = "cuda"
+vqgan = load_default(device)
+vqgan.eval()
+processor = ProcessorGradientFlow(device=device)
+clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+clip.to(device)
+promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
+state = ImageState(vqgan, promptoptim)
+def set_img_from_example(img):
+    return state.update_images(img, img, 0)
+def get_cleared_mask():
+    return gr.Image.update(value=None)
+    # mask.clear()
+with gr.Blocks(css="styles.css") as demo:
+    with gr.Row():
+        with gr.Column(scale=1):
+            blue_eyes = gr.Slider(
+                label="Blue Eyes",
+                minimum=-.8,
+                maximum=3.,
+                value=0,
+                step=0.1,
+            )
+            # hair_green_purple = gr.Slider(
+            #     label="hair green<->purple ",
+            #     minimum=-.8,
+            #     maximum=.8,
+            #     value=0,
+            #     step=0.1,
+            # )
+            lip_size = gr.Slider(
+                label="Lip Size",
+                minimum=-1.9,
+                value=0,
+                maximum=1.9,
+                step=0.1,
+            )
+            blend_weight = gr.Slider(
+                label="Blend faces: 0 is base image, 1 is the second img",
+                minimum=-0.,
+                value=0,
+                maximum=1.,
+                step=0.1,
+            )
+            # requantize = gr.Checkbox(
+            #     label="Requantize Latents (necessary using text prompts)",
+            #     value=True,
+            # )
+            asian_weight = gr.Slider(
+                minimum=-2.,
+                value=0,
+                label="Asian",
+                maximum=2.,
+                step=0.07,
+            )
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown(value="""## Image Upload
+                                For best results, crop the photos like in the example pictures""", show_label=False)
+                    with gr.Row():
+                        base_img = gr.Image(label="Base Image", type="filepath")
+                        blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
+                    # gr.Markdown("## Image Examples")
+                    with gr.Accordion(label="Add Mask", open=False):
+                        mask = gr.Image(tool="sketch", interactive=True)
+                        gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio bug)")
+                        set_mask = gr.Button(value="Set mask")
+                        gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
+                        testim = gr.Image()
+                        clear_mask = gr.Button(value="Clear mask")
+                        clear_mask.click(get_cleared_mask, outputs=mask)
+                    with gr.Row():
+                        gr.Examples(
+                            examples=glob.glob("test_pics/*"),
+                            inputs=base_img,
+                            outputs=blend_img,
+                            fn=set_img_from_example,
+                            # cache_examples=True,
+                        )
+        with gr.Column(scale=1):
+            out = gr.Image()
+            rewind = gr.Slider(value=100,
+                               label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
+                               minimum=0,
+                               maximum=100)
+
+            apply_prompts = gr.Button(value="Apply Prompts", elem_id="apply")
+            clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
+            with gr.Accordion(label="Save Animation", open=False):
+                gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
+                duration = gr.Number(value=10, label="Duration of the animation in seconds")
+                extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
+                gif = gr.File(interactive=False)
+                create_animation = gr.Button(value="Create Animation")
+                create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
+
+        with gr.Column(scale=1):
+            gr.Markdown(value="""## Text Prompting
+                        See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
+            positive_prompts = gr.Textbox(label="Positive prompts",
+                                          value="a picture of a woman with a very big nose | a picture of a woman with a large wide nose | a woman with an extremely prominent nose")
+            negative_prompts = gr.Textbox(label="Negative prompts",
+                                          value="a picture of a person with a tiny nose | a picture of a person with a very thin nose")
+            gen_prompts = gr.Button(value="🎲 Random prompts")
+            gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
+            with gr.Row():
+                with gr.Column():
+                    gr.Text(value="Prompt Editing Configuration", show_label=False)
+                    with gr.Row():
+                        gr.Markdown(value="## Preset Configs", show_label=False)
+                    with gr.Row():
+                        with gr.Column():
+                            small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
+                        with gr.Column():
+                            major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
+                        with gr.Column():
+                            major_global = gr.Button(value="Major Global Changes (e.g. change race / gender)").style(full_width=False)
+                    iterations = gr.Slider(minimum=10,
+                                           maximum=300,
+                                           step=1,
+                                           value=20,
+                                           label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out; you can always resume afterwards",)
+                    learning_rate = gr.Slider(minimum=1e-3,
+                                              maximum=6e-1,
+                                              value=1e-2,
+                                              label="Learning Rate: How strong the change in each step will be. Raise it for bigger changes (for example, changing hair color) and lower it for more minor changes. Raise it if changes aren't strong enough")
+            with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
+                lpips_weight = gr.Slider(minimum=0,
+                                         maximum=50,
+                                         value=1,
+                                         label="Perceptual similarity weight (keeps areas outside of the mask looking similar to the original; increase it if the rest of the image is changing too much while you're trying to make a localized edit)")
+                reconstruction_steps = gr.Slider(minimum=0,
+                                                 maximum=50,
+                                                 value=15,
+                                                 step=1,
+                                                 label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
+                # discriminator_steps = gr.Slider(minimum=0,
+                #                                 maximum=50,
+                #                                 step=1,
+                #                                 value=0,
+                #                                 label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
+    clear.click(state.clear_transforms, outputs=[out, mask])
+    asian_weight.change(state.apply_gender_vector, inputs=[asian_weight], outputs=[out, mask])
+    lip_size.change(state.apply_lip_vector, inputs=[lip_size], outputs=[out, mask])
+    # hair_green_purple.change(state.apply_gp_vector, inputs=[hair_green_purple], outputs=[out, mask])
+    blue_eyes.change(state.apply_rb_vector, inputs=[blue_eyes], outputs=[out, mask])
+
+    blend_weight.change(state.blend, inputs=[blend_weight], outputs=[out, mask])
+    # requantize.change(state.update_requant, inputs=[requantize], outputs=[out, mask])
+
+
+    base_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
+    blend_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
+
+    small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    apply_prompts.click(state.apply_prompts, inputs=[positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[out, mask])
+    rewind.change(state.rewind, inputs=[rewind], outputs=[out, mask])
+    set_mask.click(state.set_mask, inputs=mask, outputs=testim)
+demo.queue()
+demo.launch(debug=True, inbrowser=True)
+# if __name__ == "__main__":
+#     import argparse
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging output')
+#     args = parser.parse_args()
+#     # if args.debug:
+#     #     state=None
+#     #     promptoptim=None
+#     # else:
+#     main()
app_backend.py
ADDED
@@ -0,0 +1,243 @@
+# from functools import cache
+import importlib
+
+import gradio as gr
+import matplotlib.pyplot as plt
+import torch
+import torchvision
+import wandb
+from icecream import ic
+from torch import nn
+from torchvision.transforms.functional import resize
+from tqdm import tqdm
+from transformers import CLIPModel, CLIPProcessor
+import lpips
+from edit import blend_paths
+from img_processing import *
+from img_processing import custom_to_pil
+from loaders import load_default
+import glob
+# global log
+log=False
+
+# ic.disable()
+# ic.enable()
+def get_resized_tensor(x):
+    if len(x.shape) == 2:
+        re = x.unsqueeze(0)
+    else: re = x
+    re = resize(re, (10, 10))
+    return re
+class ProcessorGradientFlow():
+    """
+    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
+    The original processor forces conversion to PIL images, which breaks gradient flow.
+    """
+    def __init__(self, device="cuda") -> None:
+        self.device = device
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
+        self.image_std = [0.26862954, 0.26130258, 0.27577711]
+        self.normalize = torchvision.transforms.Normalize(
+            self.image_mean,
+            self.image_std
+        )
+        self.resize = torchvision.transforms.Resize(224)
+        self.center_crop = torchvision.transforms.CenterCrop(224)
+    def preprocess_img(self, images):
+        images = self.center_crop(images)
+        images = self.resize(images)
+        images = self.center_crop(images)
+        images = self.normalize(images)
+        return images
+    def __call__(self, images=[], **kwargs):
+        processed_inputs = self.processor(**kwargs)
+        processed_inputs["pixel_values"] = self.preprocess_img(images)
+        processed_inputs = {key:value.to(self.device) for (key, value) in processed_inputs.items()}
+        return processed_inputs
+
+class ImagePromptOptimizer(nn.Module):
+    def __init__(self,
+                 vqgan,
+                 clip,
+                 clip_preprocessor,
+                 iterations=100,
+                 lr = 0.01,
+                 save_vector=True,
+                 return_val="vector",
+                 quantize=True,
+                 make_grid=False,
+                 lpips_weight = 6.2) -> None:
+
+        super().__init__()
+        self.latent = None
+        self.device = vqgan.device
+        vqgan.eval()
+        self.vqgan = vqgan
+        self.clip = clip
+        self.iterations = iterations
+        self.lr = lr
+        self.clip_preprocessor = clip_preprocessor
+        self.make_grid = make_grid
+        self.return_val = return_val
+        self.quantize = quantize
+        self.disc = load_disc(self.device)
+        self.lpips_weight = lpips_weight
+        self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
+    def disc_loss_fn(self, logits):
+        return -torch.mean(logits)
+    def set_latent(self, latent):
+        self.latent = latent.detach().to(self.device)
+    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
+        self.attn_mask = attn_mask
+        self.iterations = iterations
+        self.lr = lr
+        self.lpips_weight = lpips_weight
+        self.reconstruction_steps = reconstruction_steps
+    def forward(self, vector):
+        base_latent = self.latent.detach().requires_grad_()
+        trans_latent = base_latent + vector
+        if self.quantize:
+            z_q, *_ = self.vqgan.quantize(trans_latent)
+        else:
+            z_q = trans_latent
+        dec = self.vqgan.decode(z_q)
+        return dec
+    def _get_clip_similarity(self, prompts, image, weights=None):
+        if isinstance(prompts, str):
+            prompts = [prompts]
+        elif not isinstance(prompts, list):
+            raise TypeError("Provide prompts as string or list of strings")
+        clip_inputs = self.clip_preprocessor(text=prompts,
+                                             images=image, return_tensors="pt", padding=True)
+        clip_outputs = self.clip(**clip_inputs)
+        similarity_logits = clip_outputs.logits_per_image
+        if weights:
+            similarity_logits *= weights
+        return similarity_logits.sum()
+    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
+        pos_logits = self._get_clip_similarity(pos_prompts, image)
+        if neg_prompts:
+            neg_logits = self._get_clip_similarity(neg_prompts, image)
+        else:
+            neg_logits = torch.tensor([1], device=self.device)
+        loss = -torch.log(pos_logits) + torch.log(neg_logits)
+        return loss
+    def visualize(self, processed_img):
+        if self.make_grid:
+            self.index += 1
+            plt.subplot(1, 13, self.index)
+            plt.imshow(get_pil(processed_img[0]).detach().cpu())
+        else:
+            plt.imshow(get_pil(processed_img[0]).detach().cpu())
+            plt.show()
+    def attn_masking(self, grad):
+        # print("attnmask 1")
+        # print(f"input grad.shape = {grad.shape}")
+        # print(f"input grad = {get_resized_tensor(grad)}")
+        newgrad = grad
+        if self.attn_mask is not None:
+            # print("masking mult")
+            newgrad = grad * (self.attn_mask)
+        # print("output grad, ", get_resized_tensor(newgrad))
+        # print("end atn 1")
+        return newgrad
+    def attn_masking2(self, grad):
+        # print("attnmask 2")
+        # print(f"input grad.shape = {grad.shape}")
+        # print(f"input grad = {get_resized_tensor(grad)}")
+        newgrad = grad
+        if self.attn_mask is not None:
+            # print("masking mult")
+            newgrad = grad * ((self.attn_mask - 1) * -1)
+        # print("output grad, ", get_resized_tensor(newgrad))
+        # print("end atn 2")
+        return newgrad
+
+    def optimize(self, latent, pos_prompts, neg_prompts):
+        self.set_latent(latent)
+        # self.make_grid=True
+        transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
+        original_img = loop_post_process(transformed_img)
+        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
+        optim = torch.optim.Adam([vector], lr=self.lr)
+        if self.make_grid:
+            plt.figure(figsize=(35, 25))
+            self.index = 1
+        for i in tqdm(range(self.iterations)):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img = loop_post_process(transformed_img) #* self.attn_mask
+            processed_img.retain_grad()
+            lpips_input = processed_img.clone()
+            lpips_input.register_hook(self.attn_masking2)
+            lpips_input.retain_grad()
+            clip_clone = processed_img.clone()
+            clip_clone.register_hook(self.attn_masking)
+            clip_clone.retain_grad()
+            with torch.autocast("cuda"):
+                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
+                print("CLIP loss", clip_loss)
+                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
+                print("LPIPS loss: ", perceptual_loss)
+                with torch.no_grad():
+                    disc_logits = self.disc(transformed_img)
+                    disc_loss = self.disc_loss_fn(disc_logits)
+                    print(f"disc_loss = {disc_loss}")
+                    disc_loss2 = self.disc(processed_img)
+            if log:
+                wandb.log({"Perceptual Loss": perceptual_loss})
+                wandb.log({"Discriminator Loss": disc_loss})
+                wandb.log({"CLIP Loss": clip_loss})
+            clip_loss.backward(retain_graph=True)
+            perceptual_loss.backward(retain_graph=True)
+            p2 = processed_img.grad
+            print("Sum Loss", perceptual_loss + clip_loss)
+            optim.step()
+            # if i % self.iterations // 10 == 0:
+            #     self.visualize(transformed_img)
+            yield vector
+        if self.make_grid:
+            plt.savefig(f"plot {pos_prompts[0]}.png")
+            plt.show()
+        print("lpips solo op")
+        for i in range(self.reconstruction_steps):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img = loop_post_process(transformed_img) #* self.attn_mask
+            processed_img.retain_grad()
+            lpips_input = processed_img.clone()
+            lpips_input.register_hook(self.attn_masking2)
+            lpips_input.retain_grad()
+            with torch.autocast("cuda"):
+                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
+                with torch.no_grad():
+                    disc_logits = self.disc(transformed_img)
+                    disc_loss = self.disc_loss_fn(disc_logits)
+                    print(f"disc_loss = {disc_loss}")
+                    disc_loss2 = self.disc(processed_img)
+            # print(f"disc_loss2 = {disc_loss2}")
+            if log:
+                wandb.log({"Perceptual Loss": perceptual_loss})
+            print("LPIPS loss: ", perceptual_loss)
+            perceptual_loss.backward(retain_graph=True)
+            optim.step()
+            yield vector
+        # torch.save(vector, "nose_vector.pt")
+        # print("")
+        # print("DISC STEPS")
+        # print("*************")
+        # for i in range(self.reconstruction_steps):
+        #     optim.zero_grad()
+        #     transformed_img = self(vector)
+        #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
+        #     disc_logits = self.disc(transformed_img)
+        #     disc_loss = self.disc_loss_fn(disc_logits)
+        #     print(f"disc_loss = {disc_loss}")
+        #     if log:
+        #         wandb.log({"Disc Loss": disc_loss})
+        #     print("LPIPS loss: ", perceptual_loss)
+        #     disc_loss.backward(retain_graph=True)
+        #     optim.step()
+        #     yield vector
+        yield vector if self.return_val == "vector" else self.latent + vector
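The piece that makes text prompting work end to end is ProcessorGradientFlow: the stock CLIPProcessor round-trips images through PIL/NumPy, which detaches them from the autograd graph, so the wrapper redoes the resize, crop and normalize steps with torchvision ops that accept tensors and keep gradients flowing from the CLIP loss back to the latent vector. A stripped-down sketch of the same idea (illustrative only; the constants are the standard CLIP normalization values already used above):

    import torch
    import torchvision
    import torchvision.transforms.functional as TF

    clip_normalize = torchvision.transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )

    def differentiable_clip_preprocess(images: torch.Tensor) -> torch.Tensor:
        # images: (B, 3, H, W) floats in [0, 1]; every op here is differentiable
        images = TF.resize(images, 224)
        images = TF.center_crop(images, 224)
        return clip_normalize(images)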
configs.py
ADDED
@@ -0,0 +1,7 @@
+import gradio as gr
+def set_small_local():
+    return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
+def set_major_local():
+    return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
+def set_major_global():
+    return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
edit.py
ADDED
@@ -0,0 +1,69 @@
+import os
+import sys
+
+from img_processing import custom_to_pil, preprocess, preprocess_vqgan
+
+sys.path.append("taming-transformers")
+import glob
+
+import gradio as gr
+import matplotlib.pyplot as plt
+import PIL
+import taming
+import torch
+
+from loaders import load_config
+from utils import get_device
+
+
+def get_embedding(model, path=None, img=None, device="cpu"):
+    assert path is None or img is None, "Input either path or tensor"
+    if img is not None:
+        raise NotImplementedError
+    x = preprocess(PIL.Image.open(path), target_image_size=256).to(device)
+    x_processed = preprocess_vqgan(x)
+    x_latent, _, [_, _, indices] = model.encode(x_processed)
+    return x_latent
+
+
+def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
+    x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
+    y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
+    x_latent, y_latent = get_embedding(model, path=path1, device=device), get_embedding(model, path=path2, device=device)
+    z = torch.lerp(x_latent, y_latent, weight)
+    if quantize:
+        z = model.quantize(z)[0]
+    decoded = model.decode(z)[0]
+    if show:
+        plt.figure(figsize=(10, 20))
+        plt.subplot(1, 3, 1)
+        plt.imshow(x.cpu().permute(0, 2, 3, 1)[0])
+        plt.subplot(1, 3, 2)
+        plt.imshow(custom_to_pil(decoded))
+        plt.subplot(1, 3, 3)
+        plt.imshow(y.cpu().permute(0, 2, 3, 1)[0])
+        plt.show()
+    return custom_to_pil(decoded), z
+
+if __name__ == "__main__":
+    device = get_device()
+    # conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
+    ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
+    # ckpt_path = "./faceshq/faceshq.pt"
+    conf_path = "./unwrapped.yaml"
+    # conf_path = "./faceshq/faceshq.yaml"
+    config = load_config(conf_path, display=False)
+    model = taming.models.vqgan.VQModel(**config.model.params)
+    sd = torch.load("./vqgan_only.pt", map_location="mps")
+    model.load_state_dict(sd, strict=True)
+    model.to(device)
+    blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
+    plt.show()
+
+    demo = gr.Interface(
+        get_image,
+        inputs=gr.inputs.Image(label="Upload a black and white face", type="filepath"),
+        outputs="image",
+        title="Upload a black and white face and get a colorized image!",
+    )
+
img_processing.py
ADDED
@@ -0,0 +1,72 @@
+import io
+import os
+import sys
+
+import numpy as np
+import PIL
+import requests
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageDraw, ImageFont
+
+
+def download_image(url):
+    resp = requests.get(url)
+    resp.raise_for_status()
+    return PIL.Image.open(io.BytesIO(resp.content))
+
+
+def preprocess(img, target_image_size=256, map_dalle=False):
+    s = min(img.size)
+
+    if s < target_image_size:
+        raise ValueError(f'min dim for image {s} < {target_image_size}')
+
+    r = target_image_size / s
+    s = (round(r * img.size[1]), round(r * img.size[0]))
+    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
+    img = TF.center_crop(img, output_size=2 * [target_image_size])
+    img = torch.unsqueeze(T.ToTensor()(img), 0)
+    return img
+
+def preprocess_vqgan(x):
+    x = 2.*x - 1.
+    return x
+
+def custom_to_pil(x, process=True, mode="RGB"):
+    x = x.detach().cpu()
+    if process:
+        x = torch.clamp(x, -1., 1.)
+        x = (x + 1.)/2.
+    x = x.permute(1,2,0).numpy()
+    if process:
+        x = (255*x).astype(np.uint8)
+    x = Image.fromarray(x)
+    if not x.mode == mode:
+        x = x.convert(mode)
+    return x
+
+def get_pil(x):
+    x = torch.clamp(x, -1., 1.)
+    x = (x + 1.)/2.
+    x = x.permute(1,2,0)
+    return x
+
+def loop_post_process(x):
+    x = get_pil(x.squeeze())
+    return x.permute(2, 0, 1).unsqueeze(0)
+
+def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
+    assert input.size == x1.size == x2.size == x3.size
+    w, h = input.size[0], input.size[1]
+    img = Image.new("RGB", (5*w, h))
+    img.paste(input, (0,0))
+    img.paste(x0, (1*w,0))
+    img.paste(x1, (2*w,0))
+    img.paste(x2, (3*w,0))
+    img.paste(x3, (4*w,0))
+    for i, title in enumerate(titles):
+        ImageDraw.Draw(img).text((i*w, 0), f'{title}', (255, 255, 255), font=font) # coordinates, text, color, font
+    return img
loaders.py
ADDED
@@ -0,0 +1,97 @@
+import importlib
+
+import numpy as np
+import taming
+import torch
+import yaml
+from omegaconf import OmegaConf
+from PIL import Image
+from taming.models.vqgan import VQModel
+
+from utils import get_device
+# import discriminator
+
+def load_config(config_path, display=False):
+    config = OmegaConf.load(config_path)
+    if display:
+        print(yaml.dump(OmegaConf.to_container(config)))
+    return config
+
+# def load_disc(device):
+#     dconf = load_config("disc_config.yaml")
+#     sd = torch.load("disc.pt", map_location=device)
+#     # print(sd.keys())
+#     model = discriminator.NLayerDiscriminator()
+#     model.load_state_dict(sd, strict=True)
+#     model.to(device)
+#     return model
+#     print(dconf.keys())
+
+def load_default(device):
+    # device = get_device()
+    ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
+    conf_path = "./unwrapped.yaml"
+    config = load_config(conf_path, display=False)
+    model = taming.models.vqgan.VQModel(**config.model.params)
+    sd = torch.load("./vqgan_only.pt", map_location=device)
+    model.load_state_dict(sd, strict=True)
+    model.to(device)
+    return model
+
+
+def load_vqgan(config, ckpt_path=None, is_gumbel=False):
+    if is_gumbel:
+        model = GumbelVQ(**config.model.params)
+    else:
+        model = VQModel(**config.model.params)
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()
+
+def load_ffhq():
+    conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
+    ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
+    vqgan = load_model(load_config(conf), ckpt, True, True)[0]
+
+def reconstruct_with_vqgan(x, model):
+    # could also use model(x) for reconstruction but use explicit encoding and decoding here
+    z, _, [_, _, indices] = model.encode(x)
+    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
+    xrec = model.decode(z)
+    return xrec
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+def instantiate_from_config(config):
+
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+    model = instantiate_from_config(config)
+    if sd is not None:
+        model.load_state_dict(sd)
+    if gpu:
+        model.cuda()
+    if eval_mode:
+        model.eval()
+    return {"model": model}
+
+
+def load_model(config, ckpt, gpu, eval_mode):
+    # load the specified checkpoint
+    if ckpt:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+        global_step = pl_sd["global_step"]
+        print(f"loaded model from global step {global_step}.")
+    else:
+        pl_sd = {"state_dict": None}
+        global_step = None
+    model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
+    return model, global_step
masking.py
ADDED
@@ -0,0 +1,32 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import torch
+
+sys.path.append("taming-transformers")
+import functools
+
+import gradio as gr
+from transformers import CLIPModel, CLIPProcessor
+
+import edit
+# import importlib
+# importlib.reload(edit)
+from app_backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
+from loaders import load_default
+
+device = "cuda"
+vqgan = load_default(device)
+vqgan.eval()
+processor = ProcessorGradientFlow(device=device)
+clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+clip.to(device)
+promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
+state = ImageState(vqgan, promptoptim)
+mask = torch.load("eyebrow_mask.pt")
+x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
+plt.imshow(x)
+plt.show()
+state.apply_prompts("a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask)
+print('done')
prompts.py
ADDED
@@ -0,0 +1,17 @@
+import random
+class PromptSet:
+    def __init__(self, pos, neg, config=None):
+        self.positive = pos
+        self.negative = neg
+        self.config = config
+example_prompts = (
+    PromptSet("a picture of a woman with light blonde hair", "a picture of a person with dark hair | a picture of a person with brown hair"),
+    PromptSet("A picture of a woman with very thick eyebrows", "a picture of a person with very thin eyebrows | a picture of a person with no eyebrows"),
+    PromptSet("A picture of a woman wearing bright red lipstick", "a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick"),
+    PromptSet("A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman", "a picture of a white woman | a picture of an Indian woman | a picture of a black woman"),
+    PromptSet("A picture of a handsome man | a picture of a masculine man", "a picture of a woman | a picture of a feminine person"),
+    PromptSet("A picture of a woman with a very big nose", "a picture of a person with a small nose | a picture of a person with a normal nose"),
+)
+def get_random_prompts():
+    prompt = random.choice(example_prompts)
+    return prompt.positive, prompt.negative
requirements.txt
ADDED
@@ -0,0 +1,27 @@
+taming-transformers
+einops
+gradio
+icecream
+imageio
+matplotlib
+more_itertools
+numpy
+omegaconf
+opencv_python_headless
+Pillow
+prompts
+pudb
+pytorch_lightning
+PyYAML
+requests
+scikit_image
+scipy
+setuptools
+streamlit
+torch
+torchvision
+tqdm
+transformers
+typing_extensions
+wandb
+lpips
unwrapped.yaml
ADDED
@@ -0,0 +1,37 @@
+model:
+  target: taming.models.vqgan.VQModel
+  params:
+    embed_dim: 256
+    n_embed: 1024
+    ddconfig:
+      double_z: false
+      z_channels: 256
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult:
+      - 1
+      - 1
+      - 2
+      - 2
+      - 4
+      num_res_blocks: 2
+      attn_resolutions:
+      - 16
+      dropout: 0.0
+    lossconfig:
+      target: taming.modules.losses.vqperceptual.DummyLoss
+data:
+  target: cutlit.DataModuleFromConfig
+  params:
+    batch_size: 24
+    num_workers: 24
+    train:
+      target: taming.data.faceshq.CelebAHQTrain
+      params:
+        size: 256
+    validation:
+      target: taming.data.faceshq.CelebAHQValidation
+      params:
+        size: 256
utils.py
ADDED
@@ -0,0 +1,18 @@
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage.color import lab2rgb, rgb2lab
+from torch import nn
+
+
+def freeze_module(module):
+    for param in module.parameters():
+        param.requires_grad = False
+def get_device():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = "mps"
+    return (device)
vqgan_latent_ops.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from gradient_flow_ops import ReplaceGrad
+
+replace_grad = ReplaceGrad.apply
+
+def vector_quantize(x, codebook):
+
+    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
+    indices = d.argmin(-1)
+    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
+    return replace_grad(x_q, x)