Spaces:
Runtime error
Runtime error
Update ledits/pipeline_leditspp_stable_diffusion_xl.py
Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py
CHANGED
@@ -613,11 +613,10 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
613 |
else:
|
614 |
# "2" because SDXL always indexes from the penultimate layer.
|
615 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
616 |
-
|
617 |
-
|
618 |
-
if avg_diff is not None
|
619 |
-
|
620 |
-
print("SHALOM")
|
621 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
622 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
623 |
if i == 0:
|
@@ -626,14 +625,26 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
626 |
standard_weights = torch.ones_like(weights)
|
627 |
|
628 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
629 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
630 |
else:
|
631 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
632 |
|
633 |
standard_weights = torch.ones_like(weights)
|
634 |
|
635 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
636 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
637 |
|
638 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
639 |
i+=1
|
|
|
613 |
else:
|
614 |
# "2" because SDXL always indexes from the penultimate layer.
|
615 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
616 |
+
|
617 |
+
|
618 |
+
if avg_diff is not None:
|
619 |
+
|
|
|
620 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
621 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
622 |
if i == 0:
|
|
|
625 |
standard_weights = torch.ones_like(weights)
|
626 |
|
627 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
628 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
629 |
+
weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
630 |
+
|
631 |
+
if avg_diff_2nd is not None:
|
632 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
|
633 |
+
self.pipe.tokenizer.model_max_length,
|
634 |
+
1) * scale_2nd)
|
635 |
else:
|
636 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
637 |
|
638 |
standard_weights = torch.ones_like(weights)
|
639 |
|
640 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
641 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
642 |
+
weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
643 |
+
if avg_diff_2nd is not None:
|
644 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
|
645 |
+
self.pipe.tokenizer_2.model_max_length,
|
646 |
+
1) * scale_2nd)
|
647 |
+
|
648 |
|
649 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
650 |
i+=1
|