Commit · 7f97ed5
Parent(s): bccf74a

fix bug that prevents clip extension; extend clip from 97 to 147 tokens

Files changed:
  adaface/adaface_wrapper.py  +13 -13
  app.py                      +1  -1
adaface/adaface_wrapper.py  CHANGED

@@ -117,14 +117,6 @@ class AdaFaceWrapper(nn.Module):
         else:
             vae = None
 
-        if self.use_ds_text_encoder:
-            # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
-            # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
-            text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder",
-                                                          torch_dtype=torch.float16)
-        else:
-            text_encoder = None
-
         remove_unet = False
 
         if self.pipeline_name == "img2img":
@@ -202,6 +194,13 @@ class AdaFaceWrapper(nn.Module):
 
         pipeline.unet = unet2
 
+        if self.use_ds_text_encoder:
+            # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
+            # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
+            pipeline.text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder",
+                                                                  torch_dtype=torch.float16)
+            print("Replaced the text encoder with the DreamShaper text encoder.")
+
         # Extending prompt length is for SD 1.5 only.
         if (self.pipeline_name == "text2img") and (self.max_prompt_length > 77):
             # pipeline.text_encoder.text_model.embeddings.position_embedding.weight: [77, 768] -> [max_length, 768]
@@ -210,20 +209,21 @@ class AdaFaceWrapper(nn.Module):
             # a larger max_position_embeddings, and set ignore_mismatched_sizes=True,
             # then the old position embeddings won't be loaded from the pretrained ckpt,
             # leading to degenerated performance.
+            # max_prompt_length <= 77 + 70 = 147.
+            self.max_prompt_length = min(self.max_prompt_length, 147)
+            # Number of extra tokens is at most 70.
             EL = self.max_prompt_length - 77
             # position_embedding.weight: [77, 768] -> [max_length, 768]
             new_position_embedding = extend_nn_embedding(pipeline.text_encoder.text_model.embeddings.position_embedding,
                                                          pipeline.text_encoder.text_model.embeddings.position_embedding.weight[-EL:])
             pipeline.text_encoder.text_model.embeddings.position_embedding = new_position_embedding
             pipeline.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_prompt_length).unsqueeze(0)
-
+            pipeline.text_encoder.text_model.config.max_position_embeddings = self.max_prompt_length
+            pipeline.tokenizer.model_max_length = self.max_prompt_length
+
         if self.use_840k_vae:
             pipeline.vae = vae
             print("Replaced the VAE with the 840k-step VAE.")
-
-        if self.use_ds_text_encoder:
-            pipeline.text_encoder = text_encoder
-            print("Replaced the text encoder with the DreamShaper text encoder.")
 
         if remove_unet:
             # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
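Reading the two hunks together: before this commit, the DreamShaper text encoder was assigned to pipeline.text_encoder only after the position embedding had been extended, which silently discarded the extension. The commit moves that assignment ahead of the extension, caps max_prompt_length at 147 (77 original positions plus at most 70 extra), and also updates config.max_position_embeddings and tokenizer.model_max_length so the rest of the pipeline respects the new length.

The extend_nn_embedding helper itself is not part of this diff. As a rough sketch of what the call above implies (appending a copy of the last EL rows of the original 77-row CLIP position table), it plausibly looks like the following; the signature and details are assumptions, not the repository's actual code:

import torch
import torch.nn as nn

def extend_nn_embedding(old_embedding: nn.Embedding, extra_weight: torch.Tensor) -> nn.Embedding:
    # Sketch only: the repo's real extend_nn_embedding may differ.
    # old_embedding.weight: [77, 768] for the SD 1.5 CLIP text encoder.
    # extra_weight: [EL, 768]; in the diff above, a copy of the last EL original rows.
    new_weight = torch.cat([old_embedding.weight.data, extra_weight], dim=0)  # [77 + EL, 768]
    # from_pretrained keeps the dtype/device of new_weight; freeze=False leaves it trainable.
    return nn.Embedding.from_pretrained(new_weight, freeze=False)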
app.py  CHANGED

@@ -34,7 +34,7 @@ parser.add_argument('--num_inference_steps', type=int, default=50,
 parser.add_argument('--ablate_prompt_embed_type', type=str, default='ada',
                     choices=["ada", "arc2face", "consistentID"],
                     help="Ablate to use the image ID embs instead of Ada embs")
-parser.add_argument('--max_prompt_length', type=int, default=97,
+parser.add_argument('--max_prompt_length', type=int, default=147,
                     help="Maximum length of the prompt. If > 77, the CLIP text encoder will be extended.")
 
 parser.add_argument('--gpu', type=int, default=None)
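With the default raised to 147, prompts longer than 77 CLIP tokens should survive tokenization instead of being truncated. A small illustrative check, not part of the commit: it mimics what the wrapper now does to its tokenizer, using the stock CLIP ViT-L/14 tokenizer from the Hub rather than the one the Space builds internally.

from transformers import CLIPTokenizer

# SD 1.5 uses the CLIP ViT-L/14 tokenizer; model_max_length is 77 by default.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokenizer.model_max_length = 147  # same effect as the commit's pipeline.tokenizer change

long_prompt = "a portrait photo of a person wearing a red scarf, " * 10  # well over 77 tokens
token_ids = tokenizer(long_prompt, padding="max_length", truncation=True,
                      return_tensors="pt").input_ids
print(token_ids.shape)  # [1, 147] instead of the usual [1, 77]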