Erwann Millon committed · Commit ec39fe8 · 1 Parent: 006354e

refactoring and cleanup
Files changed:
- ImageState.py +8 -29
- animation.py +3 -0
- app.py +0 -12
- backend.py +1 -1
ImageState.py CHANGED

@@ -31,7 +31,6 @@ class PromptTransformHistory():
 
 class ImageState:
     def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
-        # global vqgan
         self.vqgan = vqgan
         self.device = vqgan.device
         self.blend_latent = None
@@ -41,14 +40,11 @@ class ImageState:
         self.transform_history = []
         self.attn_mask = None
         self.prompt_optim = prompt_optimizer
-        self.state_id = None
-        print(self.state_id)
         self._load_vectors()
         self.init_transforms()
     def _load_vectors(self):
         self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
-        self.
-        self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
+        self.blue_eyes_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
         self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
     def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
         images = []
@@ -71,7 +67,6 @@ class ImageState:
         self.lip_size = torch.zeros_like(self.lip_vector)
         self.asian_transform = torch.zeros_like(self.lip_vector)
         self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
-        self.hair_gp = torch.zeros_like(self.lip_vector)
     def clear_transforms(self):
         global num
         self.init_transforms()
@@ -95,25 +90,22 @@ class ImageState:
         attn_mask = mask
         return attn_mask
     def set_mask(self, img):
-        attn_mask = self._get_mask(img)
-
-        # attn_mask = torch.ones_like(img, device=self.device)
-        x = attn_mask.clone()
+        self.attn_mask = self._get_mask(img)
+        x = self.attn_mask.clone()
         x = x.detach().cpu()
         x = torch.clamp(x, -1., 1.)
         x = (x + 1.)/2.
         x = x.numpy()
-        x = (255*x).astype(np.uint8)
+        x = (255 * x).astype(np.uint8)
         x = Image.fromarray(x, "L")
         return x
     @torch.no_grad()
     def _render_all_transformations(self, return_twice=True):
         global num
-        # global vqgan
         if self.state_id is None:
             self.state_id = "./img_history/" + str(uuid.uuid4())
         print("redner all", self.state_id)
-        current_vector_transforms = (self.blue_eyes, self.lip_size, self.
+        current_vector_transforms = (self.blue_eyes, self.lip_size, self.asian_transform, sum(self.current_prompt_transforms))
         new_latent = self.blend_latent + sum(current_vector_transforms)
         if self.quant:
             new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
@@ -126,17 +118,13 @@ class ImageState:
         image.save(f"{img_dir}/img_{num:06}.png")
         num += 1
         return (image, image) if return_twice else image
-    def apply_gp_vector(self, weight):
-        self.hair_gp = weight * self.green_purple_vector
-        return self._render_all_transformations()
     def apply_rb_vector(self, weight):
-        self.blue_eyes = weight * self.
+        self.blue_eyes = weight * self.blue_eyes_vector
         return self._render_all_transformations()
     def apply_lip_vector(self, weight):
         self.lip_size = weight * self.lip_vector
         return self._render_all_transformations()
-    def
-        print(f"val = {val}")
+    def update_quant(self, val):
         self.quant = val
         return self._render_all_transformations()
     def apply_asian_vector(self, weight):
@@ -144,11 +132,7 @@ class ImageState:
         return self._render_all_transformations()
     def update_images(self, path1, path2, blend_weight):
         if path1 is None and path2 is None:
-            print("no paths")
             return None
-        if path1 == path2:
-            print("paths are the same")
-            print(path1)
         if path1 is None: path1 = path2
         if path2 is None: path2 = path1
         self.path1, self.path2 = path1, path2
@@ -203,9 +187,4 @@ class ImageState:
         self.attn_mask = None
         self.transform_history.append(transform_log)
         gc.collect()
-        torch.cuda.empty_cache()
-        # transform = self.prompt_optim.optimize(self.blend_latent,
-        #                                   positive_prompts,
-        #                                   negative_prompts)
-        # self.prompt_transforms = transform
-        # return self._render_all_transformations()
+        torch.cuda.empty_cache()
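Note: the set_mask rewrite above stores the mask on the instance but keeps the grayscale-preview conversion intact. A minimal runnable sketch of that conversion as a standalone helper (the name mask_to_pil is ours, not the repo's):

import numpy as np
import torch
from PIL import Image

def mask_to_pil(mask: torch.Tensor) -> Image.Image:
    # Map a single-channel mask in roughly [-1, 1] to 8-bit grayscale,
    # mirroring the clamp/rescale steps in ImageState.set_mask.
    x = mask.clone().detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.                       # [-1, 1] -> [0, 1]
    x = (255 * x.numpy()).astype(np.uint8)  # [0, 1] -> [0, 255]
    return Image.fromarray(x, "L")          # "L" = 8-bit grayscale

preview = mask_to_pil(torch.rand(64, 64) * 2 - 1)  # e.g. a 64x64 preview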
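The render path these hunks touch is simple latent arithmetic: each edit is a direction vector added to the blended VQGAN latent, optionally snapped back to the codebook before decoding. A hedged sketch, assuming the vqgan object exposes taming-transformers-style quantize/decode methods (signatures assumed, not shown in this diff):

import torch

def render(vqgan, blend_latent, edit_vectors, quant=True):
    # Linear edits in latent space: the slider weights scale the stored
    # direction vectors before they are summed here.
    new_latent = blend_latent + sum(edit_vectors)
    if quant:
        # Taming-transformers quantizers return (quantized, loss, info).
        new_latent, _, _ = vqgan.quantize(new_latent)
    return vqgan.decode(new_latent)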
animation.py CHANGED

@@ -2,6 +2,7 @@ import imageio
 import glob
 import os
 
+
 def clear_img_dir(img_dir):
     if not os.path.exists("img_history"):
         os.mkdir("img_history")
@@ -10,6 +11,7 @@ def clear_img_dir(img_dir):
     for filename in glob.glob(img_dir+"/*"):
         os.remove(filename)
 
+
 def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
     images = []
     paths = glob.glob(folder + "/*")
@@ -26,5 +28,6 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="
     imageio.mimsave(gif_name, images, duration=durations)
     return gif_name
 
+
 if __name__ == "__main__":
     create_gif()
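For reference, create_gif (unchanged here apart from spacing) reads the frames that _render_all_transformations saved to the history folder and writes them out with per-frame durations. A simplified sketch — the extend_frames handling is omitted, and make_gif is our name, not the repo's:

import glob
import imageio

def make_gif(folder="./img_history", gif_name="face_edit.gif", total_duration=2.0):
    paths = sorted(glob.glob(folder + "/*"))          # frames in save order
    images = [imageio.imread(p) for p in paths]
    # imageio accepts one duration (in seconds) per frame.
    durations = [total_duration / len(images)] * len(images)
    imageio.mimsave(gif_name, images, duration=durations)
    return gif_name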
app.py CHANGED

@@ -100,7 +100,6 @@ with gr.Blocks(css="styles.css") as demo:
                 label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
                 minimum=0,
                 maximum=100)
-
             apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
             clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
             blue_eyes = gr.Slider(
@@ -110,13 +109,6 @@ with gr.Blocks(css="styles.css") as demo:
                 value=0,
                 step=0.1,
             )
-            # hair_green_purple = gr.Slider(
-            #     label="hair green<->purple ",
-            #     minimum=-.8,
-            #     maximum=.8,
-            #     value=0,
-            #     step=0.1,
-            # )
             lip_size = gr.Slider(
                 label="Lip Size",
                 minimum=-1.9,
@@ -131,10 +123,6 @@ with gr.Blocks(css="styles.css") as demo:
                 maximum=1.,
                 step=0.1,
             )
-            # requantize = gr.Checkbox(
-            #     label="Requantize Latents (necessary using text prompts)",
-            #     value=True,
-            # )
             asian_weight = gr.Slider(
                 minimum=-2.,
                 value=0,
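The deleted blocks above were commented-out controls for transforms removed in this commit (the hair_green_purple slider for green_purple_vector, and the requantize checkbox superseded by update_quant). The surviving sliders are presumably bound to the ImageState methods; those bindings are not part of this diff, so the following is only an illustrative sketch of the usual gr.Blocks wiring (handler name and slider bounds are placeholders):

import gradio as gr

def apply_blue_eyes(weight):
    # Stand-in for ImageState.apply_rb_vector(weight), which scales
    # blue_eyes_vector by `weight` and re-renders the image.
    return None

with gr.Blocks() as demo:
    out = gr.Image()
    blue_eyes = gr.Slider(label="Blue Eyes", minimum=-1., maximum=1., value=0, step=0.1)
    # Re-render whenever the slider moves; the callback's return value
    # replaces the displayed image.
    blue_eyes.change(fn=apply_blue_eyes, inputs=blue_eyes, outputs=out)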
backend.py CHANGED

@@ -33,7 +33,7 @@ def get_resized_tensor(x):
 class ProcessorGradientFlow():
     """
     This wraps the huggingface CLIP processor to allow backprop through the image processing step.
-    The original processor forces conversion to PIL images, which breaks gradient flow.
+    The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
     """
     def __init__(self, device="cuda") -> None:
         self.device = device
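The docstring fix above is the whole backend.py change, but it states the key design constraint: the stock CLIP processor round-trips through numpy/PIL, which detaches the autograd graph. A hedged sketch of the tensor-only preprocessing such a wrapper needs — the class body is not in this diff, so the helper below is our illustration, using the published CLIP normalization constants:

import torch
import torch.nn.functional as F

CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

def preprocess_differentiable(img: torch.Tensor, size: int = 224) -> torch.Tensor:
    # img: (B, 3, H, W) in [0, 1]; everything stays a torch tensor,
    # so gradients can flow from a CLIP loss back into the image/latent.
    img = F.interpolate(img, size=(size, size), mode="bicubic", align_corners=False)
    mean = CLIP_MEAN.to(img.device).view(1, 3, 1, 1)
    std = CLIP_STD.to(img.device).view(1, 3, 1, 1)
    return (img - mean) / std

# Unlike the stock processor, no numpy/PIL round trip occurs, so
# preprocess_differentiable(x).sum().backward() reaches x.grad.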