adaface-neurips committed on
Commit bccf74a · 1 Parent(s): 566ec8f

Extend CLIP text encoder to support 97 tokens

Files changed (3)
  1. adaface/adaface_wrapper.py +21 -3
  2. adaface/util.py +20 -0
  3. app.py +5 -1
adaface/adaface_wrapper.py CHANGED
@@ -14,7 +14,7 @@ from diffusers import (
     LCMScheduler,
 )
 from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
-from adaface.util import UNetEnsemble
+from adaface.util import UNetEnsemble, extend_nn_embedding
 from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
 from safetensors.torch import load_file as safetensors_load_file
@@ -27,7 +27,7 @@ class AdaFaceWrapper(nn.Module):
                  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
                  enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
                  num_inference_steps=50, subject_string='z', negative_prompt=None,
-                 use_840k_vae=False, use_ds_text_encoder=False,
+                 max_prompt_length=77, use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
                  attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
@@ -56,6 +56,9 @@ class AdaFaceWrapper(nn.Module):

         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
+
+        self.max_prompt_length = max_prompt_length
+
         self.use_840k_vae = use_840k_vae
         self.use_ds_text_encoder = use_ds_text_encoder
         self.main_unet_filepath = main_unet_filepath
@@ -199,6 +202,21 @@ class AdaFaceWrapper(nn.Module):

         pipeline.unet = unet2

+        # Extending the prompt length is for SD 1.5 only.
+        if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
+            # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_prompt_length, 768].
+            # We reuse the last EL position embeddings for the new position embeddings.
+            # If we use the "neat" way, i.e., initialize CLIPTextModel with a CLIPTextConfig with
+            # a larger max_position_embeddings, and set ignore_mismatched_sizes=True,
+            # then the old position embeddings won't be loaded from the pretrained ckpt,
+            # leading to degraded performance.
+            EL = self.max_prompt_length - 77
+            # position_embedding.weight: [77, 768] -> [max_prompt_length, 768].
+            new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
+                                                         pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
+            pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
+            pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")
@@ -715,7 +733,7 @@ class AdaFaceWrapper(nn.Module):
                  ref_img_strength=0.8, generator=None,
                  ablate_prompt_only_placeholders=False,
                  ablate_prompt_no_placeholders=False,
-                 ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
+                 ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img1', 'img2'.
                  nonmix_prompt_emb_weight=0,
                  repeat_prompt_for_each_encoder=True,
                  verbose=False):
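For readers who want the extension in isolation: a minimal, self-contained sketch of the same trick on a bare CLIP text encoder, assuming the standard diffusers/transformers SD 1.5 layout (768-dim text embeddings, 77 pretrained positions). The helper name extend_position_embeddings and the commented-out checkpoint path are illustrative and not part of this commit.

import torch
import torch.nn as nn
from transformers import CLIPTextModel

def extend_position_embeddings(text_encoder, max_prompt_length=97):
    # Assumption: text_encoder is a CLIPTextModel with a [77, 768] position_embedding, as in SD 1.5.
    emb = text_encoder.text_model.embeddings
    old_pe = emb.position_embedding                      # nn.Embedding(77, 768)
    EL = max_prompt_length - old_pe.num_embeddings       # e.g. 97 - 77 = 20 extra positions
    new_pe = nn.Embedding(max_prompt_length, old_pe.embedding_dim,
                          device=old_pe.weight.device, dtype=old_pe.weight.dtype)
    with torch.no_grad():
        # Keep all pretrained rows, then reuse the last EL rows for the new positions,
        # as the commit does, instead of re-initializing them randomly.
        new_pe.weight[:old_pe.num_embeddings] = old_pe.weight
        new_pe.weight[old_pe.num_embeddings:] = old_pe.weight[-EL:]
    emb.position_embedding = new_pe
    emb.position_ids = torch.arange(max_prompt_length).unsqueeze(0)
    return text_encoder

# Example usage (hypothetical checkpoint path):
# text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
# extend_position_embeddings(text_encoder, max_prompt_length=97)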
adaface/util.py CHANGED
@@ -73,6 +73,26 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))


+# new_token_embeddings: [new_num_tokens, 768].
+def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
+    emb_dim = old_nn_embedding.embedding_dim
+    num_old_tokens = old_nn_embedding.num_embeddings
+    num_new_tokens = new_token_embeddings.shape[0]
+    num_tokens2 = num_old_tokens + num_new_tokens
+
+    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
+                                    device=old_nn_embedding.weight.device,
+                                    dtype=old_nn_embedding.weight.dtype)
+
+    old_num_tokens = old_nn_embedding.weight.shape[0]
+    # Copy the first old_num_tokens embeddings from old_nn_embedding to new_nn_embedding.
+    new_nn_embedding.weight.data[:old_num_tokens] = old_nn_embedding.weight.data
+    # Copy the new embeddings to new_nn_embedding.
+    new_nn_embedding.weight.data[old_num_tokens:] = new_token_embeddings
+
+    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
+    return new_nn_embedding
+
 # Revised from RevGrad, by removing the grad negation.
 class ScaleGrad(torch.autograd.Function):
     @staticmethod
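As a quick sanity check (not part of the commit), extend_nn_embedding can be exercised on a toy embedding, mirroring how adaface_wrapper.py calls it with the last 20 rows of the 77-entry position table; this assumes the repo root is on PYTHONPATH.

import torch
import torch.nn as nn
from adaface.util import extend_nn_embedding

old_emb = nn.Embedding(77, 768)
# Append copies of the last 20 rows, growing the table from 77 to 97 entries.
new_emb = extend_nn_embedding(old_emb, old_emb.weight[-20:])
assert new_emb.num_embeddings == 97
assert torch.equal(new_emb.weight[:77], old_emb.weight)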
app.py CHANGED
@@ -34,6 +34,8 @@ parser.add_argument('--num_inference_steps', type=int, default=50,
 parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
                     choices=["ada", "arc2face", "consistentID"],
                     help="Ablate to use the image ID embs instead of Ada embs")
+parser.add_argument('--max_prompt_length', type=int, default=97,
+                    help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")

 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")
@@ -79,6 +81,7 @@ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_
                          adaface_encoder_types=args.adaface_encoder_types,
                          adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu',
                          num_inference_steps=args.num_inference_steps,
+                         max_prompt_length=args.max_prompt_length,
                          is_on_hf_space=is_on_hf_space)

 basedir = os.getcwd()
@@ -208,7 +211,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
     if args.ablate_prompt_embed_type != "ada":
         # Find the prompt_emb_type index in adaface_encoder_types
         # adaface_encoder_types: ["consistentID", "arc2face"]
-        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type)
+        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type) + 1
         ablate_prompt_embed_type = f"img{ablate_prompt_embed_index}"
     else:
         ablate_prompt_embed_type = "ada"
@@ -270,6 +273,7 @@ def check_prompt_and_model_type(prompt, model_style_type, progress=gr.Progress()
                                 adaface_encoder_types=args.adaface_encoder_types,
                                 adaface_ckpt_paths=[args.adaface_ckpt_path], device='cpu',
                                 num_inference_steps=args.num_inference_steps,
+                                max_prompt_length=args.max_prompt_length,
                                 is_on_hf_space=is_on_hf_space)
     # Update base model type.
     args.model_style_type = model_style_type
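The "+ 1" in generate_video makes the ablation label 1-based, so it lines up with the 'img1'/'img2' values now documented for ablate_prompt_embed_type in adaface_wrapper.py. A tiny illustration (not repo code), assuming the encoder order given in the comment above:

adaface_encoder_types = ["consistentID", "arc2face"]
for enc in adaface_encoder_types:
    idx = adaface_encoder_types.index(enc) + 1
    print(f"{enc} -> img{idx}")   # consistentID -> img1, arc2face -> img2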