Erwann Millon commited on
Commit
a23872f
·
0 Parent(s):

initial commit

Browse files
Files changed (17) hide show
  1. .gitattributes +1 -0
  2. .gitignore +9 -0
  3. .gitmodules +3 -0
  4. ImageState.py +192 -0
  5. animation.py +34 -0
  6. app.py +195 -0
  7. app_backend.py +243 -0
  8. configs.py +7 -0
  9. edit.py +69 -0
  10. img_processing.py +72 -0
  11. loaders.py +97 -0
  12. masking.py +32 -0
  13. prompts.py +17 -0
  14. requirements.txt +27 -0
  15. unwrapped.yaml +37 -0
  16. utils.py +18 -0
  17. vqgan_latent_ops.py +14 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ vqgan_only.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ __pycache__/*
3
+ img_history
4
+ aligned
5
+ wandb
6
+ women
7
+ men
8
+ gradio_cache*
9
+ .DS_Store
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "taming"]
2
+ path = taming
3
+ url = https://github.com/CompVis/taming-transformers.git
ImageState.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from align import align_from_path
2
+ from animation import clear_img_dir
3
+ from app_backend import ImagePromptOptimizer, log
4
+ from functools import cache
5
+ import importlib
6
+
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import torch
10
+ import torchvision
11
+ import wandb
12
+ from icecream import ic
13
+ from torch import nn
14
+ from torchvision.transforms.functional import resize
15
+ from tqdm import tqdm
16
+ from transformers import CLIPModel, CLIPProcessor
17
+ import lpips
18
+ from app_backend import get_resized_tensor
19
+ from edit import blend_paths
20
+ from img_processing import *
21
+ from img_processing import custom_to_pil
22
+ from loaders import load_default
23
+
24
+ num = 0
25
+ class PromptTransformHistory():
26
+ def __init__(self, iterations) -> None:
27
+ self.iterations = iterations
28
+ self.transforms = []
29
+
30
+ class ImageState:
31
+ def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
32
+ self.vqgan = vqgan
33
+ self.device = vqgan.device
34
+ self.blend_latent = None
35
+ self.quant = True
36
+ self.path1 = None
37
+ self.path2 = None
38
+ self.transform_history = []
39
+ self.attn_mask = None
40
+ self.prompt_optim = prompt_optimizer
41
+ self._load_vectors()
42
+ self.init_transforms()
43
+ def _load_vectors(self):
44
+ self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
45
+ self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
46
+ self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
47
+ self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
48
+ def init_transforms(self):
49
+ self.blue_eyes = torch.zeros_like(self.lip_vector)
50
+ self.lip_size = torch.zeros_like(self.lip_vector)
51
+ self.asian_transform = torch.zeros_like(self.lip_vector)
52
+ self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
53
+ self.hair_gp = torch.zeros_like(self.lip_vector)
54
+ def clear_transforms(self):
55
+ global num
56
+ self.init_transforms()
57
+ clear_img_dir()
58
+ num = 0
59
+ return self._render_all_transformations()
60
+ def _apply_vector(self, src, vector):
61
+ new_latent = torch.lerp(src, src + vector, 1)
62
+ return new_latent
63
+ def _decode_latent_to_pil(self, latent):
64
+ current_im = self.vqgan.decode(latent.to(self.device))[0]
65
+ return custom_to_pil(current_im)
66
+ # def _get_current_vector_transforms(self):
67
+ # current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
68
+ # return (self.blend_latent, current_vector_transforms)
69
+ # @cache
70
+ def get_mask(self, img, mask=None):
71
+ if img and "mask" in img and img["mask"] is not None:
72
+ attn_mask = torchvision.transforms.ToTensor()(img["mask"])
73
+ attn_mask = torch.ceil(attn_mask[0].to(self.device))
74
+ plt.imshow(attn_mask.detach().cpu(), cmap="Blues")
75
+ plt.show()
76
+ torch.save(attn_mask, "test_mask.pt")
77
+ print("mask set successfully")
78
+ # attn_mask = self.rescale_mask(attn_mask)
79
+ print(type(attn_mask))
80
+ print(attn_mask.shape)
81
+ else:
82
+ attn_mask = mask
83
+ print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
84
+ return attn_mask
85
+ def set_mask(self, img):
86
+ attn_mask = self.get_mask(img)
87
+ self.attn_mask = attn_mask
88
+ # attn_mask = torch.ones_like(img, device=self.device)
89
+ x = attn_mask.clone()
90
+ x = x.detach().cpu()
91
+ x = torch.clamp(x, -1., 1.)
92
+ x = (x + 1.)/2.
93
+ x = x.numpy()
94
+ x = (255*x).astype(np.uint8)
95
+ x = Image.fromarray(x, "L")
96
+ return x
97
+ @torch.no_grad()
98
+ def _render_all_transformations(self, return_twice=True):
99
+ global num
100
+ current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
101
+ new_latent = self.blend_latent + sum(current_vector_transforms)
102
+ if self.quant:
103
+ new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
104
+ image = self._decode_latent_to_pil(new_latent)
105
+ img_dir = "./img_history"
106
+ if not os.path.exists(img_dir):
107
+ os.mkdir(img_dir)
108
+ image.save(f"./img_history/img_{num:06}.png")
109
+ num += 1
110
+ return (image, image) if return_twice else image
111
+ def apply_gp_vector(self, weight):
112
+ self.hair_gp = weight * self.green_purple_vector
113
+ return self._render_all_transformations()
114
+ def apply_rb_vector(self, weight):
115
+ self.blue_eyes = weight * self.red_blue_vector
116
+ return self._render_all_transformations()
117
+ def apply_lip_vector(self, weight):
118
+ self.lip_size = weight * self.lip_vector
119
+ return self._render_all_transformations()
120
+ def update_requant(self, val):
121
+ print(f"val = {val}")
122
+ self.quant = val
123
+ return self._render_all_transformations()
124
+ def apply_gender_vector(self, weight):
125
+ self.asian_transform = weight * self.asian_vector
126
+ return self._render_all_transformations()
127
+ def update_images(self, path1, path2, blend_weight):
128
+ if path1 is None and path2 is None:
129
+ return None
130
+ if path1 is None: path1 = path2
131
+ if path2 is None: path2 = path1
132
+ self.path1, self.path2 = path1, path2
133
+ # self.aligned_path1 = align_from_path(path1)
134
+ # self.aligned_path2 = align_from_path(path2)
135
+ return self.blend(blend_weight)
136
+ @torch.no_grad()
137
+ def blend(self, weight):
138
+ _, latent = blend_paths(self.vqgan, self.path1, self.path2, weight=weight, show=False, device=self.device)
139
+ self.blend_latent = latent
140
+ return self._render_all_transformations()
141
+ @torch.no_grad()
142
+ def rewind(self, index):
143
+ if not self.transform_history:
144
+ print("no history")
145
+ return self._render_all_transformations()
146
+ prompt_transform = self.transform_history[-1]
147
+ latent_index = int(index / 100 * (prompt_transform.iterations - 1))
148
+ print(latent_index)
149
+ self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
150
+ # print(self.current_prompt_transform)
151
+ # print(self.current_prompt_transforms.mean())
152
+ return self._render_all_transformations()
153
+ def rescale_mask(self, mask):
154
+ rep = mask.clone()
155
+ rep[mask < 0.03] = -1000000
156
+ rep[mask >= 0.03] = 1
157
+ return rep
158
+ def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
159
+ transform_log = PromptTransformHistory(iterations + reconstruction_steps)
160
+ transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
161
+ self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
162
+ if log:
163
+ wandb.init(reinit=True, project="face-editor")
164
+ wandb.config.update({"Positive Prompts": positive_prompts})
165
+ wandb.config.update({"Negative Prompts": negative_prompts})
166
+ wandb.config.update(dict(
167
+ lr=lr,
168
+ iterations=iterations,
169
+ lpips_weight=lpips_weight
170
+ ))
171
+ positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
172
+ negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
173
+ self.prompt_optim.set_params(lr, iterations, lpips_weight, attn_mask=self.attn_mask, reconstruction_steps=reconstruction_steps)
174
+ for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
175
+ positive_prompts,
176
+ negative_prompts)):
177
+ transform_log.transforms.append(transform.clone().detach())
178
+ self.current_prompt_transforms[-1] = transform
179
+ with torch.no_grad():
180
+ image = self._render_all_transformations(return_twice=False)
181
+ if log:
182
+ wandb.log({"image": wandb.Image(image)})
183
+ yield (image, image)
184
+ if log:
185
+ wandb.finish()
186
+ self.attn_mask = None
187
+ self.transform_history.append(transform_log)
188
+ # transform = self.prompt_optim.optimize(self.blend_latent,
189
+ # positive_prompts,
190
+ # negative_prompts)
191
+ # self.prompt_transforms = transform
192
+ # return self._render_all_transformations()
animation.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio
2
+ import glob
3
+ import os
4
+
5
+ def clear_img_dir():
6
+ img_dir = "./img_history"
7
+ if not os.path.exists(img_dir):
8
+ os.mkdir(img_dir)
9
+ for filename in glob.glob(img_dir+"/*"):
10
+ os.remove(filename)
11
+
12
+
13
+ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
14
+ images = []
15
+ paths = glob.glob(folder + "/*")
16
+ frame_duration = total_duration / len(paths)
17
+ print(len(paths), "frame dur", frame_duration)
18
+ durations = [frame_duration] * len(paths)
19
+ if extend_frames:
20
+ durations [0] = 1.5
21
+ durations [-1] = 3
22
+ for file_name in os.listdir(folder):
23
+ if file_name.endswith('.png'):
24
+ file_path = os.path.join(folder, file_name)
25
+ images.append(imageio.imread(file_path))
26
+ # images[0] = images[0].set_meta_data({'duration': 1})
27
+ # images[-1] = images[-1].set_meta_data({'duration': 1})
28
+ imageio.mimsave(gif_name, images, duration=durations)
29
+ return gif_name
30
+
31
+ if __name__ == "__main__":
32
+ # clear_img_dir()
33
+ create_gif()
34
+ # make_animation()
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import sys
4
+
5
+ import wandb
6
+
7
+ from configs import set_major_global, set_major_local, set_small_local
8
+
9
+ sys.path.append("taming-transformers")
10
+ import functools
11
+
12
+ import gradio as gr
13
+ from transformers import CLIPModel, CLIPProcessor
14
+
15
+ import edit
16
+ # import importlib
17
+ # importlib.reload(edit)
18
+ from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
19
+ from ImageState import ImageState
20
+ from loaders import load_default
21
+ from animation import create_gif
22
+ from prompts import get_random_prompts
23
+
24
+ device = "cuda"
25
+ vqgan = load_default(device)
26
+ vqgan.eval()
27
+ processor = ProcessorGradientFlow(device=device)
28
+ clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
29
+ clip.to(device)
30
+ promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
31
+ state = ImageState(vqgan, promptoptim)
32
+ def set_img_from_example(img):
33
+ return state.update_images(img, img, 0)
34
+ def get_cleared_mask():
35
+ return gr.Image.update(value=None)
36
+ # mask.clear()
37
+ with gr.Blocks(css="styles.css") as demo:
38
+ with gr.Row():
39
+ with gr.Column(scale=1):
40
+ blue_eyes = gr.Slider(
41
+ label="Blue Eyes",
42
+ minimum=-.8,
43
+ maximum=3.,
44
+ value=0,
45
+ step=0.1,
46
+ )
47
+ # hair_green_purple = gr.Slider(
48
+ # label="hair green<->purple ",
49
+ # minimum=-.8,
50
+ # maximum=.8,
51
+ # value=0,
52
+ # step=0.1,
53
+ # )
54
+ lip_size = gr.Slider(
55
+ label="Lip Size",
56
+ minimum=-1.9,
57
+ value=0,
58
+ maximum=1.9,
59
+ step=0.1,
60
+ )
61
+ blend_weight = gr.Slider(
62
+ label="Blend faces: 0 is base image, 1 is the second img",
63
+ minimum=-0.,
64
+ value=0,
65
+ maximum=1.,
66
+ step=0.1,
67
+ )
68
+ # requantize = gr.Checkbox(
69
+ # label="Requantize Latents (necessary using text prompts)",
70
+ # value=True,
71
+ # )
72
+ asian_weight = gr.Slider(
73
+ minimum=-2.,
74
+ value=0,
75
+ label="Asian",
76
+ maximum=2.,
77
+ step=0.07,
78
+ )
79
+ with gr.Row():
80
+ with gr.Column():
81
+ gr.Markdown(value="""## Image Upload
82
+ For best results, crop the photos like in the example pictures""", show_label=False)
83
+ with gr.Row():
84
+ base_img = gr.Image(label="Base Image", type="filepath")
85
+ blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
86
+ # gr.Markdown("## Image Examples")
87
+ with gr.Accordion(label="Add Mask", open=False):
88
+ mask = gr.Image(tool="sketch", interactive=True)
89
+ gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio bug)")
90
+ set_mask = gr.Button(value="Set mask")
91
+ gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
92
+ testim = gr.Image()
93
+ clear_mask = gr.Button(value="Clear mask")
94
+ clear_mask.click(get_cleared_mask, outputs=mask)
95
+ with gr.Row():
96
+ gr.Examples(
97
+ examples=glob.glob("test_pics/*"),
98
+ inputs=base_img,
99
+ outputs=blend_img,
100
+ fn=set_img_from_example,
101
+ # cache_examples=True,
102
+ )
103
+ with gr.Column(scale=1):
104
+ out = gr.Image()
105
+ rewind = gr.Slider(value=100,
106
+ label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
107
+ minimum=0,
108
+ maximum=100)
109
+
110
+ apply_prompts = gr.Button(value="Apply Prompts", elem_id="apply")
111
+ clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
112
+ with gr.Accordion(label="Save Animation", open=False):
113
+ gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
114
+ duration = gr.Number(value=10, label="Duration of the animation in seconds")
115
+ extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
116
+ gif = gr.File(interactive=False)
117
+ create_animation = gr.Button(value="Create Animation")
118
+ create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
119
+
120
+ with gr.Column(scale=1):
121
+ gr.Markdown(value="""## Text Prompting
122
+ See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
123
+ positive_prompts = gr.Textbox(label="Positive prompts",
124
+ value="a picture of a woman with a very big nose | a picture of a woman with a large wide nose | a woman with an extremely prominent nose")
125
+ negative_prompts = gr.Textbox(label="Negative prompts",
126
+ value="a picture of a person with a tiny nose | a picture of a person with a very thin nose")
127
+ gen_prompts = gr.Button(value="🎲 Random prompts")
128
+ gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
129
+ with gr.Row():
130
+ with gr.Column():
131
+ gr.Text(value="Prompt Editing Configuration", show_label=False)
132
+ with gr.Row():
133
+ gr.Markdown(value="## Preset Configs", show_label=False)
134
+ with gr.Row():
135
+ with gr.Column():
136
+ small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
137
+ with gr.Column():
138
+ major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
139
+ with gr.Column():
140
+ major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
141
+ iterations = gr.Slider(minimum=10,
142
+ maximum=300,
143
+ step=1,
144
+ value=20,
145
+ label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
146
+ learning_rate = gr.Slider(minimum=1e-3,
147
+ maximum=6e-1,
148
+ value=1e-2,
149
+ label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
150
+ with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
151
+ lpips_weight = gr.Slider(minimum=0,
152
+ maximum=50,
153
+ value=1,
154
+ label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
155
+ reconstruction_steps = gr.Slider(minimum=0,
156
+ maximum=50,
157
+ value=15,
158
+ step=1,
159
+ label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
160
+ # discriminator_steps = gr.Slider(minimum=0,
161
+ # maximum=50,
162
+ # step=1,
163
+ # value=0,
164
+ # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
165
+ clear.click(state.clear_transforms, outputs=[out, mask])
166
+ asian_weight.change(state.apply_gender_vector, inputs=[asian_weight], outputs=[out, mask])
167
+ lip_size.change(state.apply_lip_vector, inputs=[lip_size], outputs=[out, mask])
168
+ # hair_green_purple.change(state.apply_gp_vector, inputs=[hair_green_purple], outputs=[out, mask])
169
+ blue_eyes.change(state.apply_rb_vector, inputs=[blue_eyes], outputs=[out, mask])
170
+
171
+ blend_weight.change(state.blend, inputs=[blend_weight], outputs=[out, mask])
172
+ # requantize.change(state.update_requant, inputs=[requantize], outputs=[out, mask])
173
+
174
+
175
+ base_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
176
+ blend_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
177
+
178
+ small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
179
+ major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
180
+ small_local.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
181
+ apply_prompts.click(state.apply_prompts, inputs=[positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[out, mask])
182
+ rewind.change(state.rewind, inputs=[rewind], outputs=[out, mask])
183
+ set_mask.click(state.set_mask, inputs=mask, outputs=testim)
184
+ demo.queue()
185
+ demo.launch(debug=True, inbrowser=True)
186
+ # if __name__ == "__main__":
187
+ # import argparse
188
+ # parser = argparse.ArgumentParser()
189
+ # parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging output')
190
+ # args = parser.parse_args()
191
+ # # if args.debug:
192
+ # # state=None
193
+ # # promptoptim=None
194
+ # # else:
195
+ # main()
app_backend.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from functools import cache
2
+ import importlib
3
+
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import torchvision
8
+ import wandb
9
+ from icecream import ic
10
+ from torch import nn
11
+ from torchvision.transforms.functional import resize
12
+ from tqdm import tqdm
13
+ from transformers import CLIPModel, CLIPProcessor
14
+ import lpips
15
+ from edit import blend_paths
16
+ from img_processing import *
17
+ from img_processing import custom_to_pil
18
+ from loaders import load_default
19
+ import glob
20
+ # global log
21
+ log=False
22
+
23
+ # ic.disable()
24
+ # ic.enable()
25
+ def get_resized_tensor(x):
26
+ if len(x.shape) == 2:
27
+ re = x.unsqueeze(0)
28
+ else: re = x
29
+ re = resize(re, (10, 10))
30
+ return re
31
+ class ProcessorGradientFlow():
32
+ """
33
+ This wraps the huggingface CLIP processor to allow backprop through the image processing step.
34
+ The original processor forces conversion to PIL images, which breaks gradient flow.
35
+ """
36
+ def __init__(self, device="cuda") -> None:
37
+ self.device = device
38
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
39
+ self.image_mean = [0.48145466, 0.4578275, 0.40821073]
40
+ self.image_std = [0.26862954, 0.26130258, 0.27577711]
41
+ self.normalize = torchvision.transforms.Normalize(
42
+ self.image_mean,
43
+ self.image_std
44
+ )
45
+ self.resize = torchvision.transforms.Resize(224)
46
+ self.center_crop = torchvision.transforms.CenterCrop(224)
47
+ def preprocess_img(self, images):
48
+ images = self.center_crop(images)
49
+ images = self.resize(images)
50
+ images = self.center_crop(images)
51
+ images = self.normalize(images)
52
+ return images
53
+ def __call__(self, images=[], **kwargs):
54
+ processed_inputs = self.processor(**kwargs)
55
+ processed_inputs["pixel_values"] = self.preprocess_img(images)
56
+ processed_inputs = {key:value.to(self.device) for (key, value) in processed_inputs.items()}
57
+ return processed_inputs
58
+
59
+ class ImagePromptOptimizer(nn.Module):
60
+ def __init__(self,
61
+ vqgan,
62
+ clip,
63
+ clip_preprocessor,
64
+ iterations=100,
65
+ lr = 0.01,
66
+ save_vector=True,
67
+ return_val="vector",
68
+ quantize=True,
69
+ make_grid=False,
70
+ lpips_weight = 6.2) -> None:
71
+
72
+ super().__init__()
73
+ self.latent = None
74
+ self.device = vqgan.device
75
+ vqgan.eval()
76
+ self.vqgan = vqgan
77
+ self.clip = clip
78
+ self.iterations = iterations
79
+ self.lr = lr
80
+ self.clip_preprocessor = clip_preprocessor
81
+ self.make_grid = make_grid
82
+ self.return_val = return_val
83
+ self.quantize = quantize
84
+ self.disc = load_disc(self.device)
85
+ self.lpips_weight = lpips_weight
86
+ self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
87
+ def disc_loss_fn(self, logits):
88
+ return -torch.mean(logits)
89
+ def set_latent(self, latent):
90
+ self.latent = latent.detach().to(self.device)
91
+ def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
92
+ self.attn_mask = attn_mask
93
+ self.iterations = iterations
94
+ self.lr = lr
95
+ self.lpips_weight = lpips_weight
96
+ self.reconstruction_steps = reconstruction_steps
97
+ def forward(self, vector):
98
+ base_latent = self.latent.detach().requires_grad_()
99
+ trans_latent = base_latent + vector
100
+ if self.quantize:
101
+ z_q, *_ = self.vqgan.quantize(trans_latent)
102
+ else:
103
+ z_q = trans_latent
104
+ dec = self.vqgan.decode(z_q)
105
+ return dec
106
+ def _get_clip_similarity(self, prompts, image, weights=None):
107
+ if isinstance(prompts, str):
108
+ prompts = [prompts]
109
+ elif not isinstance(prompts, list):
110
+ raise TypeError("Provide prompts as string or list of strings")
111
+ clip_inputs = self.clip_preprocessor(text=prompts,
112
+ images=image, return_tensors="pt", padding=True)
113
+ clip_outputs = self.clip(**clip_inputs)
114
+ similarity_logits = clip_outputs.logits_per_image
115
+ if weights:
116
+ similarity_logits *= weights
117
+ return similarity_logits.sum()
118
+ def get_similarity_loss(self, pos_prompts, neg_prompts, image):
119
+ pos_logits = self._get_clip_similarity(pos_prompts, image)
120
+ if neg_prompts:
121
+ neg_logits = self._get_clip_similarity(neg_prompts, image)
122
+ else:
123
+ neg_logits = torch.tensor([1], device=self.device)
124
+ loss = -torch.log(pos_logits) + torch.log(neg_logits)
125
+ return loss
126
+ def visualize(self, processed_img):
127
+ if self.make_grid:
128
+ self.index += 1
129
+ plt.subplot(1, 13, self.index)
130
+ plt.imshow(get_pil(processed_img[0]).detach().cpu())
131
+ else:
132
+ plt.imshow(get_pil(processed_img[0]).detach().cpu())
133
+ plt.show()
134
+ def attn_masking(self, grad):
135
+ # print("attnmask 1")
136
+ # print(f"input grad.shape = {grad.shape}")
137
+ # print(f"input grad = {get_resized_tensor(grad)}")
138
+ newgrad = grad
139
+ if self.attn_mask is not None:
140
+ # print("masking mult")
141
+ newgrad = grad * (self.attn_mask)
142
+ # print("output grad, ", get_resized_tensor(newgrad))
143
+ # print("end atn 1")
144
+ return newgrad
145
+ def attn_masking2(self, grad):
146
+ # print("attnmask 2")
147
+ # print(f"input grad.shape = {grad.shape}")
148
+ # print(f"input grad = {get_resized_tensor(grad)}")
149
+ newgrad = grad
150
+ if self.attn_mask is not None:
151
+ # print("masking mult")
152
+ newgrad = grad * ((self.attn_mask - 1) * -1)
153
+ # print("output grad, ", get_resized_tensor(newgrad))
154
+ # print("end atn 2")
155
+ return newgrad
156
+
157
+ def optimize(self, latent, pos_prompts, neg_prompts):
158
+ self.set_latent(latent)
159
+ # self.make_grid=True
160
+ transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
161
+ original_img = loop_post_process(transformed_img)
162
+ vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
163
+ optim = torch.optim.Adam([vector], lr=self.lr)
164
+ if self.make_grid:
165
+ plt.figure(figsize=(35, 25))
166
+ self.index = 1
167
+ for i in tqdm(range(self.iterations)):
168
+ optim.zero_grad()
169
+ transformed_img = self(vector)
170
+ processed_img = loop_post_process(transformed_img) #* self.attn_mask
171
+ processed_img.retain_grad()
172
+ lpips_input = processed_img.clone()
173
+ lpips_input.register_hook(self.attn_masking2)
174
+ lpips_input.retain_grad()
175
+ clip_clone = processed_img.clone()
176
+ clip_clone.register_hook(self.attn_masking)
177
+ clip_clone.retain_grad()
178
+ with torch.autocast("cuda"):
179
+ clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
180
+ print("CLIP loss", clip_loss)
181
+ perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
182
+ print("LPIPS loss: ", perceptual_loss)
183
+ with torch.no_grad():
184
+ disc_logits = self.disc(transformed_img)
185
+ disc_loss = self.disc_loss_fn(disc_logits)
186
+ print(f"disc_loss = {disc_loss}")
187
+ disc_loss2 = self.disc(processed_img)
188
+ if log:
189
+ wandb.log({"Perceptual Loss": perceptual_loss})
190
+ wandb.log({"Discriminator Loss": disc_loss})
191
+ wandb.log({"CLIP Loss": clip_loss})
192
+ clip_loss.backward(retain_graph=True)
193
+ perceptual_loss.backward(retain_graph=True)
194
+ p2 = processed_img.grad
195
+ print("Sum Loss", perceptual_loss + clip_loss)
196
+ optim.step()
197
+ # if i % self.iterations // 10 == 0:
198
+ # self.visualize(transformed_img)
199
+ yield vector
200
+ if self.make_grid:
201
+ plt.savefig(f"plot {pos_prompts[0]}.png")
202
+ plt.show()
203
+ print("lpips solo op")
204
+ for i in range(self.reconstruction_steps):
205
+ optim.zero_grad()
206
+ transformed_img = self(vector)
207
+ processed_img = loop_post_process(transformed_img) #* self.attn_mask
208
+ processed_img.retain_grad()
209
+ lpips_input = processed_img.clone()
210
+ lpips_input.register_hook(self.attn_masking2)
211
+ lpips_input.retain_grad()
212
+ with torch.autocast("cuda"):
213
+ perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
214
+ with torch.no_grad():
215
+ disc_logits = self.disc(transformed_img)
216
+ disc_loss = self.disc_loss_fn(disc_logits)
217
+ print(f"disc_loss = {disc_loss}")
218
+ disc_loss2 = self.disc(processed_img)
219
+ # print(f"disc_loss2 = {disc_loss2}")
220
+ if log:
221
+ wandb.log({"Perceptual Loss": perceptual_loss})
222
+ print("LPIPS loss: ", perceptual_loss)
223
+ perceptual_loss.backward(retain_graph=True)
224
+ optim.step()
225
+ yield vector
226
+ # torch.save(vector, "nose_vector.pt")
227
+ # print("")
228
+ # print("DISC STEPS")
229
+ # print("*************")
230
+ # for i in range(self.reconstruction_steps):
231
+ # optim.zero_grad()
232
+ # transformed_img = self(vector)
233
+ # processed_img = loop_post_process(transformed_img) #* self.attn_mask
234
+ # disc_logits = self.disc(transformed_img)
235
+ # disc_loss = self.disc_loss_fn(disc_logits)
236
+ # print(f"disc_loss = {disc_loss}")
237
+ # if log:
238
+ # wandb.log({"Disc Loss": disc_loss})
239
+ # print("LPIPS loss: ", perceptual_loss)
240
+ # disc_loss.backward(retain_graph=True)
241
+ # optim.step()
242
+ # yield vector
243
+ yield vector if self.return_val == "vector" else self.latent + vector
configs.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ def set_small_local():
3
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
4
+ def set_major_local():
5
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
6
+ def set_major_global():
7
+ return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
edit.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from img_processing import custom_to_pil, preprocess, preprocess_vqgan
5
+
6
+ sys.path.append("taming-transformers")
7
+ import glob
8
+
9
+ import gradio as gr
10
+ import matplotlib.pyplot as plt
11
+ import PIL
12
+ import taming
13
+ import torch
14
+
15
+ from loaders import load_config
16
+ from utils import get_device
17
+
18
+
19
+ def get_embedding(model, path=None, img=None, device="cpu"):
20
+ assert path is None or img is None, "Input either path or tensor"
21
+ if img is not None:
22
+ raise NotImplementedError
23
+ x = preprocess(PIL.Image.open(path), target_image_size=256).to(device)
24
+ x_processed = preprocess_vqgan(x)
25
+ x_latent, _, [_, _, indices] = model.encode(x_processed)
26
+ return x_latent
27
+
28
+
29
+ def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
30
+ x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
31
+ y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
32
+ x_latent, y_latent = get_embedding(model, path=path1, device=device), get_embedding(model, path=path2, device=device)
33
+ z = torch.lerp(x_latent, y_latent, weight)
34
+ if quantize:
35
+ z = model.quantize(z)[0]
36
+ decoded = model.decode(z)[0]
37
+ if show:
38
+ plt.figure(figsize=(10, 20))
39
+ plt.subplot(1, 3, 1)
40
+ plt.imshow(x.cpu().permute(0, 2, 3, 1)[0])
41
+ plt.subplot(1, 3, 2)
42
+ plt.imshow(custom_to_pil(decoded))
43
+ plt.subplot(1, 3, 3)
44
+ plt.imshow(y.cpu().permute(0, 2, 3, 1)[0])
45
+ plt.show()
46
+ return custom_to_pil(decoded), z
47
+
48
+ if __name__ == "__main__":
49
+ device = get_device()
50
+ # conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
51
+ ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
52
+ # ckpt_path = "./faceshq/faceshq.pt"
53
+ conf_path = "./unwrapped.yaml"
54
+ # conf_path = "./faceshq/faceshq.yaml"
55
+ config = load_config(conf_path, display=False)
56
+ model = taming.models.vqgan.VQModel(**config.model.params)
57
+ sd = torch.load("./vqgan_only.pt", map_location="mps")
58
+ model.load_state_dict(sd, strict=True)
59
+ model.to(device)
60
+ blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
61
+ plt.show()
62
+
63
+ demo = gr.Interface(
64
+ get_image,
65
+ inputs=gr.inputs.Image(label="UploadZz a black and white face", type="filepath"),
66
+ outputs="image",
67
+ title="Upload a black and white face and get a colorized image!",
68
+ )
69
+
img_processing.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import requests
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as TF
12
+ from PIL import Image, ImageDraw, ImageFont
13
+
14
+
15
+ def download_image(url):
16
+ resp = requests.get(url)
17
+ resp.raise_for_status()
18
+ return PIL.Image.open(io.BytesIO(resp.content))
19
+
20
+
21
+ def preprocess(img, target_image_size=256, map_dalle=False):
22
+ s = min(img.size)
23
+
24
+ if s < target_image_size:
25
+ raise ValueError(f'min dim for image {s} < {target_image_size}')
26
+
27
+ r = target_image_size / s
28
+ s = (round(r * img.size[1]), round(r * img.size[0]))
29
+ img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
30
+ img = TF.center_crop(img, output_size=2 * [target_image_size])
31
+ img = torch.unsqueeze(T.ToTensor()(img), 0)
32
+ return img
33
+
34
+ def preprocess_vqgan(x):
35
+ x = 2.*x - 1.
36
+ return x
37
+
38
+ def custom_to_pil(x, process=True, mode="RGB"):
39
+ x = x.detach().cpu()
40
+ if process:
41
+ x = torch.clamp(x, -1., 1.)
42
+ x = (x + 1.)/2.
43
+ x = x.permute(1,2,0).numpy()
44
+ if process:
45
+ x = (255*x).astype(np.uint8)
46
+ x = Image.fromarray(x)
47
+ if not x.mode == mode:
48
+ x = x.convert(mode)
49
+ return x
50
+
51
+ def get_pil(x):
52
+ x = torch.clamp(x, -1., 1.)
53
+ x = (x + 1.)/2.
54
+ x = x.permute(1,2,0)
55
+ return x
56
+
57
+ def loop_post_process(x):
58
+ x = get_pil(x.squeeze())
59
+ return x.permute(2, 0, 1).unsqueeze(0)
60
+
61
+ def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
62
+ assert input.size == x1.size == x2.size == x3.size
63
+ w, h = input.size[0], input.size[1]
64
+ img = Image.new("RGB", (5*w, h))
65
+ img.paste(input, (0,0))
66
+ img.paste(x0, (1*w,0))
67
+ img.paste(x1, (2*w,0))
68
+ img.paste(x2, (3*w,0))
69
+ img.paste(x3, (4*w,0))
70
+ for i, title in enumerate(titles):
71
+ ImageDraw.Draw(img).text((i*w, 0), f'{title}', (255, 255, 255), font=font) # coordinates, text, color, font
72
+ return img
loaders.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import numpy as np
4
+ import taming
5
+ import torch
6
+ import yaml
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from taming.models.vqgan import VQModel
10
+
11
+ from utils import get_device
12
+ # import discriminator
13
+
14
+ def load_config(config_path, display=False):
15
+ config = OmegaConf.load(config_path)
16
+ if display:
17
+ print(yaml.dump(OmegaConf.to_container(config)))
18
+ return config
19
+
20
+ # def load_disc(device):
21
+ # dconf = load_config("disc_config.yaml")
22
+ # sd = torch.load("disc.pt", map_location=device)
23
+ # # print(sd.keys())
24
+ # model = discriminator.NLayerDiscriminator()
25
+ # model.load_state_dict(sd, strict=True)
26
+ # model.to(device)
27
+ # return model
28
+ # print(dconf.keys())
29
+
30
+ def load_default(device):
31
+ # device = get_device()
32
+ ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
33
+ conf_path = "./unwrapped.yaml"
34
+ config = load_config(conf_path, display=False)
35
+ model = taming.models.vqgan.VQModel(**config.model.params)
36
+ sd = torch.load("./vqgan_only.pt", map_location=device)
37
+ model.load_state_dict(sd, strict=True)
38
+ model.to(device)
39
+ return model
40
+
41
+
42
+ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
43
+ if is_gumbel:
44
+ model = GumbelVQ(**config.model.params)
45
+ else:
46
+ model = VQModel(**config.model.params)
47
+ if ckpt_path is not None:
48
+ sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
49
+ missing, unexpected = model.load_state_dict(sd, strict=False)
50
+ return model.eval()
51
+
52
+ def load_ffhq():
53
+ conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
54
+ ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
55
+ vqgan = load_model(load_config(conf), ckpt, True, True)[0]
56
+
57
+ def reconstruct_with_vqgan(x, model):
58
+ # could also use model(x) for reconstruction but use explicit encoding and decoding here
59
+ z, _, [_, _, indices] = model.encode(x)
60
+ print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
61
+ xrec = model.decode(z)
62
+ return xrec
63
+ def get_obj_from_str(string, reload=False):
64
+ module, cls = string.rsplit(".", 1)
65
+ if reload:
66
+ module_imp = importlib.import_module(module)
67
+ importlib.reload(module_imp)
68
+ return getattr(importlib.import_module(module, package=None), cls)
69
+
70
+ def instantiate_from_config(config):
71
+
72
+ if not "target" in config:
73
+ raise KeyError("Expected key `target` to instantiate.")
74
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
75
+
76
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
77
+ model = instantiate_from_config(config)
78
+ if sd is not None:
79
+ model.load_state_dict(sd)
80
+ if gpu:
81
+ model.cuda()
82
+ if eval_mode:
83
+ model.eval()
84
+ return {"model": model}
85
+
86
+
87
+ def load_model(config, ckpt, gpu, eval_mode):
88
+ # load the specified checkpoint
89
+ if ckpt:
90
+ pl_sd = torch.load(ckpt, map_location="cpu")
91
+ global_step = pl_sd["global_step"]
92
+ print(f"loaded model from global step {global_step}.")
93
+ else:
94
+ pl_sd = {"state_dict": None}
95
+ global_step = None
96
+ model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
97
+ return model, global_step
masking.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+
7
+ sys.path.append("taming-transformers")
8
+ import functools
9
+
10
+ import gradio as gr
11
+ from transformers import CLIPModel, CLIPProcessor
12
+
13
+ import edit
14
+ # import importlib
15
+ # importlib.reload(edit)
16
+ from app_backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
17
+ from loaders import load_default
18
+
19
+ device = "cuda"
20
+ vqgan = load_default(device)
21
+ vqgan.eval()
22
+ processor = ProcessorGradientFlow(device=device)
23
+ clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
24
+ clip.to(device)
25
+ promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
26
+ state = ImageState(vqgan, promptoptim)
27
+ mask = torch.load("eyebrow_mask.pt")
28
+ x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
29
+ plt.imshow(x)
30
+ plt.show()
31
+ state.apply_prompts("a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask)
32
+ print('done')
prompts.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ class PromptSet:
3
+ def __init__(self, pos, neg, config=None):
4
+ self.positive = pos
5
+ self.negative = neg
6
+ self.config = config
7
+ example_prompts = (
8
+ PromptSet("a picture of a woman with light blonde hair", "a picture of a person with dark hair | a picture of a person with brown hair"),
9
+ PromptSet("A picture of a woman with very thick eyebrows", "a picture of a person with very thin eyebrows | a picture of a person with no eyebrows"),
10
+ PromptSet("A picture of a woman wearing bright red lipstick", "a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick"),
11
+ PromptSet("A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman", "a picture of a white woman | a picture of an Indian woman | a picture of a black woman"),
12
+ PromptSet("A picture of a handsome man | a picture of a masculine man", "a picture of a woman | a picture of a feminine person"),
13
+ PromptSet("A picture of a woman with a very big nose", "a picture of a person with a small nose | a picture of a person with a normal nose"),
14
+ )
15
+ def get_random_prompts():
16
+ prompt = random.choice(example_prompts)
17
+ return prompt.positive, prompt.negative
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ taming-transformers
2
+ einops
3
+ gradio
4
+ icecream
5
+ imageio
6
+ matplotlib
7
+ more_itertools
8
+ numpy
9
+ omegaconf
10
+ opencv_python_headless
11
+ Pillow
12
+ prompts
13
+ pudb
14
+ pytorch_lightning
15
+ PyYAML
16
+ requests
17
+ scikit_image
18
+ scipy
19
+ setuptools
20
+ streamlit
21
+ torch
22
+ torchvision
23
+ tqdm
24
+ transformers
25
+ typing_extensions
26
+ wandb
27
+ lpips
unwrapped.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: taming.models.vqgan.VQModel
3
+ params:
4
+ embed_dim: 256
5
+ n_embed: 1024
6
+ ddconfig:
7
+ double_z: false
8
+ z_channels: 256
9
+ resolution: 256
10
+ in_channels: 3
11
+ out_ch: 3
12
+ ch: 128
13
+ ch_mult:
14
+ - 1
15
+ - 1
16
+ - 2
17
+ - 2
18
+ - 4
19
+ num_res_blocks: 2
20
+ attn_resolutions:
21
+ - 16
22
+ dropout: 0.0
23
+ lossconfig:
24
+ target: taming.modules.losses.vqperceptual.DummyLoss
25
+ data:
26
+ target: cutlit.DataModuleFromConfig
27
+ params:
28
+ batch_size: 24
29
+ num_workers: 24
30
+ train:
31
+ target: taming.data.faceshq.CelebAHQTrain
32
+ params:
33
+ size: 256
34
+ validation:
35
+ target: taming.data.faceshq.CelebAHQValidation
36
+ params:
37
+ size: 256
utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from skimage.color import lab2rgb, rgb2lab
8
+ from torch import nn
9
+
10
+
11
+ def freeze_module(module):
12
+ for param in module.parameters():
13
+ param.requires_grad = False
14
+ def get_device():
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
17
+ device = "mps"
18
+ return (device)
vqgan_latent_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from gradient_flow_ops import ReplaceGrad
6
+
7
+ replace_grad = ReplaceGrad.apply
8
+
9
+ def vector_quantize(x, codebook):
10
+
11
+ d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
12
+ indices = d.argmin(-1)
13
+ x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
14
+ return replace_grad(x_q, x)