Spaces:
Configuration error
Configuration error
wip statewrapper chagne
Browse files- app.py +2 -2
- app_backend.py +7 -19
- loaders.py +1 -0
app.py
CHANGED
|
@@ -19,7 +19,7 @@ from loaders import load_default
|
|
| 19 |
from animation import create_gif
|
| 20 |
from prompts import get_random_prompts
|
| 21 |
|
| 22 |
-
device = "
|
| 23 |
vqgan = load_default(device)
|
| 24 |
vqgan.eval()
|
| 25 |
processor = ProcessorGradientFlow(device=device)
|
|
@@ -62,7 +62,7 @@ class StateWrapper:
|
|
| 62 |
return state, *state[0].update_requant(*args, **kwargs)
|
| 63 |
|
| 64 |
with gr.Blocks(css="styles.css") as demo:
|
| 65 |
-
promptoptim =
|
| 66 |
state = gr.State([ImageState(vqgan, promptoptim)])
|
| 67 |
with gr.Row():
|
| 68 |
with gr.Column(scale=1):
|
|
|
|
| 19 |
from animation import create_gif
|
| 20 |
from prompts import get_random_prompts
|
| 21 |
|
| 22 |
+
device = "cpu"
|
| 23 |
vqgan = load_default(device)
|
| 24 |
vqgan.eval()
|
| 25 |
processor = ProcessorGradientFlow(device=device)
|
|
|
|
| 62 |
return state, *state[0].update_requant(*args, **kwargs)
|
| 63 |
|
| 64 |
with gr.Blocks(css="styles.css") as demo:
|
| 65 |
+
promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
|
| 66 |
state = gr.State([ImageState(vqgan, promptoptim)])
|
| 67 |
with gr.Row():
|
| 68 |
with gr.Column(scale=1):
|
app_backend.py
CHANGED
|
@@ -174,19 +174,13 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 174 |
clip_clone = processed_img.clone()
|
| 175 |
clip_clone.register_hook(self.attn_masking)
|
| 176 |
clip_clone.retain_grad()
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
# with torch.no_grad():
|
| 183 |
-
# disc_logits = self.disc(transformed_img)
|
| 184 |
-
# disc_loss = self.disc_loss_fn(disc_logits)
|
| 185 |
-
# print(f"disc_loss = {disc_loss}")
|
| 186 |
-
# disc_loss2 = self.disc(processed_img)
|
| 187 |
if log:
|
| 188 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
| 189 |
-
# wandb.log({"Discriminator Loss": disc_loss})
|
| 190 |
wandb.log({"CLIP Loss": clip_loss})
|
| 191 |
clip_loss.backward(retain_graph=True)
|
| 192 |
perceptual_loss.backward(retain_graph=True)
|
|
@@ -208,14 +202,8 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 208 |
lpips_input = processed_img.clone()
|
| 209 |
lpips_input.register_hook(self.attn_masking2)
|
| 210 |
lpips_input.retain_grad()
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
# with torch.no_grad():
|
| 214 |
-
# disc_logits = self.disc(transformed_img)
|
| 215 |
-
# disc_loss = self.disc_loss_fn(disc_logits)
|
| 216 |
-
# print(f"disc_loss = {disc_loss}")
|
| 217 |
-
# disc_loss2 = self.disc(processed_img)
|
| 218 |
-
# print(f"disc_loss2 = {disc_loss2}")
|
| 219 |
if log:
|
| 220 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
| 221 |
print("LPIPS loss: ", perceptual_loss)
|
|
|
|
| 174 |
clip_clone = processed_img.clone()
|
| 175 |
clip_clone.register_hook(self.attn_masking)
|
| 176 |
clip_clone.retain_grad()
|
| 177 |
+
with torch.autocast("cuda"):
|
| 178 |
+
clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
|
| 179 |
+
print("CLIP loss", clip_loss)
|
| 180 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
| 181 |
+
print("LPIPS loss: ", perceptual_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if log:
|
| 183 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
|
|
|
| 184 |
wandb.log({"CLIP Loss": clip_loss})
|
| 185 |
clip_loss.backward(retain_graph=True)
|
| 186 |
perceptual_loss.backward(retain_graph=True)
|
|
|
|
| 202 |
lpips_input = processed_img.clone()
|
| 203 |
lpips_input.register_hook(self.attn_masking2)
|
| 204 |
lpips_input.retain_grad()
|
| 205 |
+
with torch.autocast("cuda"):
|
| 206 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
if log:
|
| 208 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
| 209 |
print("LPIPS loss: ", perceptual_loss)
|
loaders.py
CHANGED
|
@@ -36,6 +36,7 @@ def load_default(device):
|
|
| 36 |
sd = torch.load("./vqgan_only.pt", map_location=device)
|
| 37 |
model.load_state_dict(sd, strict=True)
|
| 38 |
model.to(device)
|
|
|
|
| 39 |
return model
|
| 40 |
|
| 41 |
|
|
|
|
| 36 |
sd = torch.load("./vqgan_only.pt", map_location=device)
|
| 37 |
model.load_state_dict(sd, strict=True)
|
| 38 |
model.to(device)
|
| 39 |
+
del sd
|
| 40 |
return model
|
| 41 |
|
| 42 |
|