erwann committed on
Commit
2783b1f
·
1 Parent(s): 29bbf75

update gradio demo

Browse files
Files changed (4) hide show
  1. README.md +9 -0
  2. app.py +0 -6
  3. backend.py +20 -28
  4. presets.py +1 -0
README.md CHANGED
@@ -9,4 +9,13 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
 
 
 
 
 
 
 
 
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  ---
11
 
12
+ # Face Editor
13
+ This face editor uses a CelebA-pretrained VQGAN with CLIP to allow prompt-based image manipulation, as well as slider-based manipulation using extracted latent vectors.
14
+
15
+ I've written a series of Medium articles which provide a detailed and beginner-friendly explanation of how this was built.
16
+
17
+ ## Features:
18
+ Edit masking using a custom backpropagation hook
19
+
20
+
21
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -183,17 +183,11 @@ with gr.Blocks(css="styles.css") as demo:
183
  value=3,
184
  step=1,
185
  label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that 'pull' the image back towards the original identity")
186
- # discriminator_steps = gr.Slider(minimum=0,
187
- # maximum=50,
188
- # step=1,
189
- # value=0,
190
- # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
191
  clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
192
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
193
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
194
  blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
195
  blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
196
- # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
197
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
198
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
199
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
 
183
  value=3,
184
  step=1,
185
  label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that 'pull' the image back towards the original identity")
 
 
 
 
 
186
  clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
187
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
188
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
189
  blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
190
  blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
 
191
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
192
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
193
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
backend.py CHANGED
@@ -79,7 +79,7 @@ class ImagePromptEditor(nn.Module):
79
  self.latent = latent.detach().to(self.device)
80
 
81
  def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
82
- self._attn_mask = attn_mask
83
  self.iterations = iterations
84
  self.lr = lr
85
  self.lpips_weight = lpips_weight
@@ -118,25 +118,16 @@ class ImagePromptEditor(nn.Module):
118
  loss = -torch.log(pos_logits) + torch.log(neg_logits)
119
  return loss
120
 
121
- def visualize(self, processed_img):
122
- if self.make_grid:
123
- self.index += 1
124
- plt.subplot(1, 13, self.index)
125
- plt.imshow(get_pil(processed_img[0]).detach().cpu())
126
- else:
127
- plt.imshow(get_pil(processed_img[0]).detach().cpu())
128
- plt.show()
129
-
130
- def _attn_mask(self, grad):
131
  newgrad = grad
132
- if self._attn_mask is not None:
133
- newgrad = grad * (self._attn_mask)
134
  return newgrad
135
 
136
- def _attn_mask_inverse(self, grad):
137
  newgrad = grad
138
- if self._attn_mask is not None:
139
- newgrad = grad * ((self._attn_mask - 1) * -1)
140
  return newgrad
141
 
142
  def _get_next_inputs(self, transformed_img):
@@ -144,11 +135,11 @@ class ImagePromptEditor(nn.Module):
144
  processed_img.retain_grad()
145
 
146
  lpips_input = processed_img.clone()
147
- lpips_input.register_hook(self._attn_mask_inverse)
148
  lpips_input.retain_grad()
149
 
150
  clip_input = processed_img.clone()
151
- clip_input.register_hook(self._attn_mask)
152
  clip_input.retain_grad()
153
 
154
  return (processed_img, lpips_input, clip_input)
@@ -160,15 +151,15 @@ class ImagePromptEditor(nn.Module):
160
  processed_img, lpips_input, clip_input = self._get_next_inputs(
161
  transformed_img
162
  )
163
- with torch.autocast("cuda"):
164
- clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
165
- print("CLIP loss", clip_loss)
166
- perceptual_loss = (
167
- self.perceptual_loss(lpips_input, original_img.clone())
168
- * self.lpips_weight
169
- )
170
- print("LPIPS loss: ", perceptual_loss)
171
- print("Sum Loss", perceptual_loss + clip_loss)
172
  if log:
173
  wandb.log({"Perceptual Loss": perceptual_loss})
174
  wandb.log({"CLIP Loss": clip_loss})
@@ -188,7 +179,7 @@ class ImagePromptEditor(nn.Module):
188
  processed_img.retain_grad()
189
 
190
  lpips_input = processed_img.clone()
191
- lpips_input.register_hook(self._attn_mask_inverse)
192
  lpips_input.retain_grad()
193
  with torch.autocast("cuda"):
194
  perceptual_loss = (
@@ -217,4 +208,5 @@ class ImagePromptEditor(nn.Module):
217
  print("Running LPIPS optim only")
218
  for transform in self._optimize_LPIPS(vector, original_img, optim):
219
  yield transform
 
220
  yield vector if self.return_val == "vector" else self.latent + vector
 
79
  self.latent = latent.detach().to(self.device)
80
 
81
  def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
82
+ self.attn_mask = attn_mask
83
  self.iterations = iterations
84
  self.lr = lr
85
  self.lpips_weight = lpips_weight
 
118
  loss = -torch.log(pos_logits) + torch.log(neg_logits)
119
  return loss
120
 
121
+ def _apply_mask(self, grad):
 
 
 
 
 
 
 
 
 
122
  newgrad = grad
123
+ if self.attn_mask is not None:
124
+ newgrad = grad * (self.attn_mask)
125
  return newgrad
126
 
127
+ def _apply_inverse_mask(self, grad):
128
  newgrad = grad
129
+ if self.attn_mask is not None:
130
+ newgrad = grad * ((self.attn_mask - 1) * -1)
131
  return newgrad
132
 
133
  def _get_next_inputs(self, transformed_img):
 
135
  processed_img.retain_grad()
136
 
137
  lpips_input = processed_img.clone()
138
+ lpips_input.register_hook(self._apply_inverse_mask)
139
  lpips_input.retain_grad()
140
 
141
  clip_input = processed_img.clone()
142
+ clip_input.register_hook(self._apply_mask)
143
  clip_input.retain_grad()
144
 
145
  return (processed_img, lpips_input, clip_input)
 
151
  processed_img, lpips_input, clip_input = self._get_next_inputs(
152
  transformed_img
153
  )
154
+ # with torch.autocast("cuda"):
155
+ clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
156
+ print("CLIP loss", clip_loss)
157
+ perceptual_loss = (
158
+ self.perceptual_loss(lpips_input, original_img.clone())
159
+ * self.lpips_weight
160
+ )
161
+ print("LPIPS loss: ", perceptual_loss)
162
+ print("Sum Loss", perceptual_loss + clip_loss)
163
  if log:
164
  wandb.log({"Perceptual Loss": perceptual_loss})
165
  wandb.log({"CLIP Loss": clip_loss})
 
179
  processed_img.retain_grad()
180
 
181
  lpips_input = processed_img.clone()
182
+ lpips_input.register_hook(self._apply_inverse_mask)
183
  lpips_input.retain_grad()
184
  with torch.autocast("cuda"):
185
  perceptual_loss = (
 
208
  print("Running LPIPS optim only")
209
  for transform in self._optimize_LPIPS(vector, original_img, optim):
210
  yield transform
211
+
212
  yield vector if self.return_val == "vector" else self.latent + vector
presets.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
 
3
 
4
  def set_preset(config_str):
 
5
  choices = [
6
  "Small Masked Changes (e.g. add lipstick)",
7
  "Major Masked Changes (e.g. change hair color or nose size)",
 
2
 
3
 
4
  def set_preset(config_str):
5
+ print(config_str)
6
  choices = [
7
  "Small Masked Changes (e.g. add lipstick)",
8
  "Major Masked Changes (e.g. change hair color or nose size)",