erwann commited on
Commit
6859b0d
·
1 Parent(s): 9d59265
Files changed (4) hide show
  1. ImageState.py +7 -3
  2. animation.py +0 -17
  3. app.py +51 -53
  4. configs.py +10 -2
ImageState.py CHANGED
@@ -78,7 +78,7 @@ class ImageState:
78
  def clear_transforms(self):
79
  global num
80
  self.init_transforms()
81
- clear_img_dir()
82
  num = 0
83
  return self._render_all_transformations()
84
  def _apply_vector(self, src, vector):
@@ -151,7 +151,11 @@ class ImageState:
151
  return self._render_all_transformations()
152
  def update_images(self, path1, path2, blend_weight):
153
  if path1 is None and path2 is None:
 
154
  return None
 
 
 
155
  if path1 is None: path1 = path2
156
  if path2 is None: path2 = path1
157
  self.path1, self.path2 = path1, path2
@@ -170,7 +174,7 @@ class ImageState:
170
  prompt_transform = self.transform_history[-1]
171
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
172
  print(latent_index)
173
- self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
174
  return self._render_all_transformations()
175
  # def rescale_mask(self, mask):
176
  # rep = mask.clone()
@@ -196,7 +200,7 @@ class ImageState:
196
  for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
197
  positive_prompts,
198
  negative_prompts)):
199
- transform_log.transforms.append(transform.clone().detach())
200
  self.current_prompt_transforms[-1] = transform
201
  with torch.no_grad():
202
  image = self._render_all_transformations(return_twice=False)
 
78
  def clear_transforms(self):
79
  global num
80
  self.init_transforms()
81
+ clear_img_dir("./img_history")
82
  num = 0
83
  return self._render_all_transformations()
84
  def _apply_vector(self, src, vector):
 
151
  return self._render_all_transformations()
152
  def update_images(self, path1, path2, blend_weight):
153
  if path1 is None and path2 is None:
154
+ print("no paths")
155
  return None
156
+ if path1 == path2:
157
+ print("paths are the same")
158
+ print(path1)
159
  if path1 is None: path1 = path2
160
  if path2 is None: path2 = path1
161
  self.path1, self.path2 = path1, path2
 
174
  prompt_transform = self.transform_history[-1]
175
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
176
  print(latent_index)
177
+ self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index].to(self.device)
178
  return self._render_all_transformations()
179
  # def rescale_mask(self, mask):
180
  # rep = mask.clone()
 
200
  for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
201
  positive_prompts,
202
  negative_prompts)):
203
+ transform_log.transforms.append(transform.detach().cpu())
204
  self.current_prompt_transforms[-1] = transform
205
  with torch.no_grad():
206
  image = self._render_all_transformations(return_twice=False)
animation.py CHANGED
@@ -9,23 +9,6 @@ def clear_img_dir(img_dir):
9
  os.remove(filename)
10
 
11
 
12
- def create_gif(total_duration, extend_frames, folder, gif_name="face_edit.gif"):
13
- images = []
14
- paths = glob.glob(folder + "/*")
15
- frame_duration = total_duration / len(paths)
16
- print(len(paths), "frame dur", frame_duration)
17
- durations = [frame_duration] * len(paths)
18
- if extend_frames:
19
- durations [0] = 1.5
20
- durations [-1] = 3
21
- for file_name in os.listdir(folder):
22
- if file_name.endswith('.png'):
23
- file_path = os.path.join(folder, file_name)
24
- images.append(imageio.imread(file_path))
25
- # images[0] = images[0].set_meta_data({'duration': 1})
26
- # images[-1] = images[-1].set_meta_data({'duration': 1})
27
- imageio.mimsave(gif_name, images, duration=durations)
28
- return gif_name
29
 
30
  if __name__ == "__main__":
31
  # clear_img_dir()
 
9
  os.remove(filename)
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  if __name__ == "__main__":
14
  # clear_img_dir()
app.py CHANGED
@@ -5,7 +5,7 @@ import sys
5
  import wandb
6
  import torch
7
 
8
- from configs import set_major_global, set_major_local, set_small_local
9
  import uuid
10
  # print()'
11
  sys.path.append("taming-transformers")
@@ -18,7 +18,7 @@ import edit
18
  from backend import ImagePromptOptimizer, ProcessorGradientFlow
19
  from ImageState import ImageState
20
  from loaders import load_default
21
- from animation import create_gif
22
  from prompts import get_random_prompts
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -70,16 +70,40 @@ class StateWrapper:
70
  return state, *state[0].update_images(*args, **kwargs)
71
  def update_requant(state, *args, **kwargs):
72
  return state, *state[0].update_requant(*args, **kwargs)
73
- def ret_id(id):
74
- print(id)
75
- return(id)
76
  with gr.Blocks(css="styles.css") as demo:
77
  id = gr.State(str(uuid.uuid4()))
78
  state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
79
  with gr.Row():
80
  with gr.Column(scale=1):
81
- x = gr.Button(label="asd")
82
- x.click(ret_id, inputs=id, outputs=id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  blue_eyes = gr.Slider(
84
  label="Blue Eyes",
85
  minimum=-.8,
@@ -119,39 +143,6 @@ with gr.Blocks(css="styles.css") as demo:
119
  maximum=2.,
120
  step=0.07,
121
  )
122
- with gr.Row():
123
- with gr.Column():
124
- gr.Markdown(value="""## Image Upload
125
- For best results, crop the photos like in the example pictures""", show_label=False)
126
- with gr.Row():
127
- base_img = gr.Image(label="Base Image", type="filepath")
128
- blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
129
- # gr.Markdown("## Image Examples")
130
- with gr.Accordion(label="Add Mask", open=False):
131
- mask = gr.Image(tool="sketch", interactive=True)
132
- gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio bug)")
133
- set_mask = gr.Button(value="Set mask")
134
- gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
135
- testim = gr.Image()
136
- # # clear_mask = gr.Button(value="Clear mask")
137
- # clear_mask.click(get_cleared_mask, outputs=mask)
138
- with gr.Row():
139
- gr.Examples(
140
- examples=glob.glob("test_pics/*"),
141
- inputs=base_img,
142
- outputs=blend_img,
143
- fn=set_img_from_example,
144
- # cache_examples=True,
145
- )
146
- with gr.Column(scale=1):
147
- out = gr.Image()
148
- rewind = gr.Slider(value=100,
149
- label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
150
- minimum=0,
151
- maximum=100)
152
-
153
- apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
154
- clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
155
  with gr.Accordion(label="💾 Save Animation", open=False):
156
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
157
  duration = gr.Number(value=10, label="Duration of the animation in seconds")
@@ -161,7 +152,7 @@ with gr.Blocks(css="styles.css") as demo:
161
  create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])
162
 
163
  with gr.Column(scale=1):
164
- gr.Markdown(value="""## Text Prompting
165
  See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
166
  positive_prompts = gr.Textbox(label="Positive prompts",
167
  value="a picture of a woman with a very big nose | a picture of a woman with a large wide nose | a woman with an extremely prominent nose")
@@ -171,29 +162,35 @@ with gr.Blocks(css="styles.css") as demo:
171
  gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
172
  with gr.Row():
173
  with gr.Column():
174
- gr.Text(value="⚙️ Prompt Editing Configuration", show_label=False)
175
  with gr.Row():
176
- gr.Markdown(value="## Preset Configs", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
177
  with gr.Row():
178
  # with gr.Column():
179
- small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
180
- # with gr.Column():
181
- major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
182
- # with gr.Column():
183
- major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
184
  iterations = gr.Slider(minimum=10,
185
  maximum=60,
186
  step=1,
187
  value=20,
188
  label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
189
  learning_rate = gr.Slider(minimum=4e-3,
190
- maximum=7e-1,
191
  value=1e-1,
192
  label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
193
  lpips_weight = gr.Slider(minimum=0,
194
  maximum=50,
195
  value=1,
196
- label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
197
  reconstruction_steps = gr.Slider(minimum=0,
198
  maximum=50,
199
  value=3,
@@ -213,11 +210,12 @@ with gr.Blocks(css="styles.css") as demo:
213
  # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
214
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
215
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
216
- small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
217
- major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
218
- major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
219
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
220
  rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
221
  set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
 
222
  demo.queue()
223
  demo.launch(debug=True, enable_queue=True)
 
5
  import wandb
6
  import torch
7
 
8
+ from configs import set_major_global, set_major_local, set_preset, set_small_local
9
  import uuid
10
  # print()'
11
  sys.path.append("taming-transformers")
 
18
  from backend import ImagePromptOptimizer, ProcessorGradientFlow
19
  from ImageState import ImageState
20
  from loaders import load_default
21
+ # from animation import create_gif
22
  from prompts import get_random_prompts
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
70
  return state, *state[0].update_images(*args, **kwargs)
71
  def update_requant(state, *args, **kwargs):
72
  return state, *state[0].update_requant(*args, **kwargs)
 
 
 
73
  with gr.Blocks(css="styles.css") as demo:
74
  id = gr.State(str(uuid.uuid4()))
75
  state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
76
  with gr.Row():
77
  with gr.Column(scale=1):
78
+ with gr.Row():
79
+ with gr.Column():
80
+ gr.Markdown(value="""## Image Upload
81
+ For best results, crop the photos like in the example pictures""", show_label=False)
82
+ with gr.Row():
83
+ base_img = gr.Image(label="Base Image", type="filepath")
84
+ blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
85
+ with gr.Accordion(label="Add Mask", open=False):
86
+ mask = gr.Image(tool="sketch", interactive=True)
87
+ gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio issue)")
88
+ set_mask = gr.Button(value="Set mask")
89
+ gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
90
+ testim = gr.Image()
91
+ with gr.Row():
92
+ gr.Examples(
93
+ examples=glob.glob("test_pics/*"),
94
+ inputs=base_img,
95
+ outputs=blend_img,
96
+ fn=set_img_from_example,
97
+ )
98
+ with gr.Column(scale=1):
99
+ out = gr.Image()
100
+ rewind = gr.Slider(value=100,
101
+ label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
102
+ minimum=0,
103
+ maximum=100)
104
+
105
+ apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
106
+ clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
107
  blue_eyes = gr.Slider(
108
  label="Blue Eyes",
109
  minimum=-.8,
 
143
  maximum=2.,
144
  step=0.07,
145
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  with gr.Accordion(label="💾 Save Animation", open=False):
147
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
148
  duration = gr.Number(value=10, label="Duration of the animation in seconds")
 
152
  create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])
153
 
154
  with gr.Column(scale=1):
155
+ gr.Markdown(value="""## ✍️ Prompt Editing
156
  See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
157
  positive_prompts = gr.Textbox(label="Positive prompts",
158
  value="a picture of a woman with a very big nose | a picture of a woman with a large wide nose | a woman with an extremely prominent nose")
 
162
  gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
163
  with gr.Row():
164
  with gr.Column():
 
165
  with gr.Row():
166
+ gr.Markdown(value="## Prompt Editing Config", show_label=False)
167
+ with gr.Accordion(label="Config Tutorial", open=False):
168
+ gr.Markdown(value="""
169
+ - If results are not changing enough, increase the learning rate or decrease the perceptual loss weight
170
+ - To make local edits, use the 'Add Mask' section
171
+ - If using a mask and the image is changing too much outside of the masked area, try increasing the perceptual loss weight or lowering the learning rate
172
+ - Use the rewind slider to scroll through the iterations of your prompt transformation, you can resume editing from any point in the history.
173
+ - I recommend starting prompts with 'a picture of a'
174
+ - To avoid shifts in gender, you can use 'a person' instead of 'a man' or 'a woman', especially in the negative prompts.
175
+ - The more 'out-of-domain' the prompts are, the more you need to increase the learning rate and decrease the perceptual loss weight. For example, trying to make a black person have platinum blond hair is more out-of-domain than the same transformation on a caucasian person.
176
+ -
177
+ """)
178
  with gr.Row():
179
  # with gr.Column():
180
+ presets = gr.Dropdown(default="Select a preset", label="Preset Configs", choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"])
 
 
 
 
181
  iterations = gr.Slider(minimum=10,
182
  maximum=60,
183
  step=1,
184
  value=20,
185
  label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
186
  learning_rate = gr.Slider(minimum=4e-3,
187
+ maximum=1,
188
  value=1e-1,
189
  label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
190
  lpips_weight = gr.Slider(minimum=0,
191
  maximum=50,
192
  value=1,
193
+ label="Perceptual Loss weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
194
  reconstruction_steps = gr.Slider(minimum=0,
195
  maximum=50,
196
  value=3,
 
210
  # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
211
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
212
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
213
+ # small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
214
+ # major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
215
+ # major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
216
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
217
  rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
218
  set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
219
+ presets.change(set_preset, inputs=[presets], outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
220
  demo.queue()
221
  demo.launch(debug=True, enable_queue=True)
configs.py CHANGED
@@ -1,7 +1,15 @@
1
  import gradio as gr
2
  def set_small_local():
3
- return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=5), gr.Slider.update(value=4))
4
  def set_major_local():
5
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.187), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
6
  def set_major_global():
7
- return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  def set_small_local():
3
+ return (gr.Slider.update(value=18), gr.Slider.update(value=0.15), gr.Slider.update(value=5), gr.Slider.update(value=4))
4
  def set_major_local():
5
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.187), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
6
  def set_major_global():
7
+ return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))
8
+ def set_preset(config_str):
9
+ choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"]
10
+ if config_str == choices[0]:
11
+ return set_small_local()
12
+ elif config_str == choices[1]:
13
+ return set_major_local()
14
+ elif config_str == choices[2]:
15
+ return set_major_global()