Commit bccf74a (parent: 566ec8f)
Extend CLIP text encoder to support 97 tokens

Files changed:
- adaface/adaface_wrapper.py  (+21 -3)
- adaface/util.py  (+20 -0)
- app.py  (+5 -1)
adaface/adaface_wrapper.py
CHANGED

@@ -14,7 +14,7 @@ from diffusers import (
     LCMScheduler,
 )
 from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
-from adaface.util import UNetEnsemble
+from adaface.util import UNetEnsemble, extend_nn_embedding
 from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
 from safetensors.torch import load_file as safetensors_load_file
@@ -27,7 +27,7 @@ class AdaFaceWrapper(nn.Module):
                  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
                  enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
                  num_inference_steps=50, subject_string='z', negative_prompt=None,
-                 use_840k_vae=False, use_ds_text_encoder=False,
+                 max_prompt_length=77, use_840k_vae=False, use_ds_text_encoder=False,
                  main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                  enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
                  attn_lora_layer_names=['q', 'k', 'v', 'out'], normalize_cross_attn=False, q_lora_updates_query=False,
@@ -56,6 +56,9 @@ class AdaFaceWrapper(nn.Module):
 
         self.default_scheduler_name = default_scheduler_name
         self.num_inference_steps = num_inference_steps if not use_lcm else 4
+
+        self.max_prompt_length = max_prompt_length
+
         self.use_840k_vae = use_840k_vae
         self.use_ds_text_encoder = use_ds_text_encoder
         self.main_unet_filepath = main_unet_filepath
@@ -199,6 +202,21 @@ class AdaFaceWrapper(nn.Module):
 
             pipeline.unet = unet2
 
+        # Extending prompt length is for SD 1.5 only.
+        if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
+            # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_length, 768]
+            # We reuse the last EL position embeddings for the new position embeddings.
+            # If we use the "neat" way, i.e., initialize CLIPTextModel with a CLIPTextConfig with
+            # a larger max_position_embeddings, and set ignore_mismatched_sizes=True,
+            # then the old position embeddings won't be loaded from the pretrained ckpt,
+            # leading to degenerated performance.
+            EL = self.max_prompt_length - 77
+            # position_embedding.weight: [77, 768] -> [max_length, 768]
+            new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
+                                                         pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
+            pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
+            pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")
@@ -715,7 +733,7 @@ class AdaFaceWrapper(nn.Module):
                 ref_img_strength=0.8, generator=None,
                 ablate_prompt_only_placeholders=False,
                 ablate_prompt_no_placeholders=False,
-                ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', '
+                ablate_prompt_embed_type='ada',  # 'ada', 'ada-nonmix', 'img1', 'img2'.
                 nonmix_prompt_emb_weight=0,
                 repeat_prompt_for_each_encoder=True,
                 verbose=False):
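For readers who want to try the same trick outside AdaFace, the sketch below is my own illustration, not part of the commit. It assumes the transformers CLIPTextModel layout used above and a placeholder SD 1.5 checkpoint name; like the hunk at new lines 205-218, it extends the position embedding from 77 to 97 slots and initializes the 20 extra rows from the last 20 pretrained rows rather than re-initializing the whole table.

# Standalone sketch of the position-embedding extension (assumptions: the
# checkpoint name is a placeholder; layout matches transformers' CLIPTextModel).
import torch
import torch.nn as nn
from transformers import CLIPTextModel

max_prompt_length = 97
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder")  # placeholder SD 1.5 ckpt

embeddings  = text_encoder.text_model.embeddings
old_pos_emb = embeddings.position_embedding                      # nn.Embedding(77, 768)
EL          = max_prompt_length - old_pos_emb.num_embeddings     # 20 extra positions

new_pos_emb = nn.Embedding(max_prompt_length, old_pos_emb.embedding_dim,
                           device=old_pos_emb.weight.device,
                           dtype=old_pos_emb.weight.dtype)
with torch.no_grad():
    # Keep all pretrained rows, then reuse the last EL rows for the new positions,
    # so the extra slots start from pretrained values instead of random init.
    new_pos_emb.weight[:old_pos_emb.num_embeddings] = old_pos_emb.weight
    new_pos_emb.weight[old_pos_emb.num_embeddings:] = old_pos_emb.weight[-EL:]

embeddings.position_embedding = new_pos_emb
embeddings.position_ids       = torch.arange(max_prompt_length).unsqueeze(0)

In a full pipeline the tokenizer call would also need max_length raised to 97 so prompts are padded to the new length, but that is outside this hunk.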
adaface/util.py
CHANGED

@@ -73,6 +73,26 @@ def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std()))
 
 
+# new_token_embeddings: [new_num_tokens, 768].
+def extend_nn_embedding(old_nn_embedding, new_token_embeddings):
+    emb_dim        = old_nn_embedding.embedding_dim
+    num_old_tokens = old_nn_embedding.num_embeddings
+    num_new_tokens = new_token_embeddings.shape[0]
+    num_tokens2    = num_old_tokens + num_new_tokens
+
+    new_nn_embedding = nn.Embedding(num_tokens2, emb_dim,
+                                    device=old_nn_embedding.weight.device,
+                                    dtype=old_nn_embedding.weight.dtype)
+
+    old_num_tokens = old_nn_embedding.weight.shape[0]
+    # Copy the first old_num_tokens embeddings from old_nn_embedding to new_nn_embedding.
+    new_nn_embedding.weight.data[:old_num_tokens] = old_nn_embedding.weight.data
+    # Copy the new embeddings to new_nn_embedding.
+    new_nn_embedding.weight.data[old_num_tokens:] = new_token_embeddings
+
+    print(f"Extended nn.Embedding from {num_old_tokens} to {num_tokens2} tokens.")
+    return new_nn_embedding
+
 # Revised from RevGrad, by removing the grad negation.
 class ScaleGrad(torch.autograd.Function):
     @staticmethod
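A quick usage illustration of the new helper (a hypothetical snippet, not part of the commit): extend_nn_embedding builds a larger nn.Embedding on the same device and dtype, copies the old rows verbatim, and appends whatever rows are passed as new_token_embeddings; the wrapper passes the last 20 position-embedding rows.

import torch
import torch.nn as nn
from adaface.util import extend_nn_embedding

# Stand-in for CLIP's position embedding: 77 positions x 768 dims.
old_emb = nn.Embedding(77, 768)

# Reuse the last 20 rows as initial values for positions 77..96, mirroring
# how adaface_wrapper.py calls the helper for max_prompt_length=97.
new_emb = extend_nn_embedding(old_emb, old_emb.weight[-20:])
# Prints: Extended nn.Embedding from 77 to 97 tokens.

assert new_emb.num_embeddings == 97
assert torch.equal(new_emb.weight[:77], old_emb.weight)        # old rows preserved
assert torch.equal(new_emb.weight[77:], old_emb.weight[-20:])  # new rows copied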
app.py
CHANGED

@@ -34,6 +34,8 @@ parser.add_argument('--num_inference_steps', type=int, default=50,
 parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
                     choices=["ada", "arc2face", "consistentID"],
                     help="Ablate to use the image ID embs instead of Ada embs")
+parser.add_argument('--max_prompt_length', type=int, default=97,
+                    help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")
 
 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")
@@ -79,6 +81,7 @@ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_
                          adaface_encoder_types=args.adaface_encoder_types,
                          adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu',
                          num_inference_steps=args.num_inference_steps,
+                         max_prompt_length=args.max_prompt_length,
                          is_on_hf_space=is_on_hf_space)
 
 basedir = os.getcwd()
@@ -208,7 +211,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
     if args.ablate_prompt_embed_type != "ada":
         # Find the prompt_emb_type index in adaface_encoder_types
         # adaface_encoder_types: ["consistentID", "arc2face"]
-        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type)
+        ablate_prompt_embed_index = args.adaface_encoder_types.index(args.ablate_prompt_embed_type) + 1
         ablate_prompt_embed_type  = f"img{ablate_prompt_embed_index}"
     else:
         ablate_prompt_embed_type  = "ada"
@@ -270,6 +273,7 @@ def check_prompt_and_model_type(prompt, model_style_type, progress=gr.Progress()
                                 adaface_encoder_types=args.adaface_encoder_types,
                                 adaface_ckpt_paths=[args.adaface_ckpt_path], device='cpu',
                                 num_inference_steps=args.num_inference_steps,
+                                max_prompt_length=args.max_prompt_length,
                                 is_on_hf_space=is_on_hf_space)
     # Update base model type.
     args.model_style_type = model_style_type
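One behavioral note on the generate_video hunk: the added "+ 1" appears to make the ablation label 1-based, so an encoder's position in adaface_encoder_types maps onto the 'img1'/'img2' values documented by the updated ablate_prompt_embed_type comment in adaface_wrapper.py. A toy check of that mapping, my own illustration using the encoder list from the comments above:

# Hypothetical check of the 1-based label mapping.
adaface_encoder_types = ["consistentID", "arc2face"]

for encoder_type in adaface_encoder_types:
    ablate_prompt_embed_index = adaface_encoder_types.index(encoder_type) + 1
    print(f"{encoder_type} -> img{ablate_prompt_embed_index}")
# consistentID -> img1
# arc2face -> img2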