Update app.py
app.py CHANGED
@@ -14,8 +14,8 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
 from transformers import CLIPProcessor, CLIPModel
-from huggingface_hub import hf_hub_download
-import gradio as gr
+from huggingface_hub import hf_hub_download
+import gradio as gr
 import math
 
 # -----------------------------------------------------------------------------
@@ -27,10 +27,6 @@ import math
 class VQVAE2(nn.Module):
     def __init__(self, n_embed=8192, embed_dim=256, ch=128):
         super().__init__()
-        # This is a simplified placeholder. The actual architecture would be more complex.
-        # The key is having a 'decode' method that matches the state_dict.
-        # A full implementation would require the original model's architecture file.
-        # For this fix, we assume a basic structure that allows loading the state_dict.
         self.decoder = nn.Sequential(
             nn.Conv2d(embed_dim, ch * 4, 3, padding=1),
             nn.ReLU(),
@@ -40,40 +36,32 @@ class VQVAE2(nn.Module):
             nn.ReLU(),
             nn.ConvTranspose2d(ch, 3, 4, stride=2, padding=1),
         )
-
+
     def decode(self, latents):
-        # A real VQVAE would involve lookup tables, but for generation we only need the decoder part.
-        # This part is highly dependent on the model checkpoint.
-        # The following is a guess to make it runnable, assuming latents are ready for the decoder.
         return self.decoder(latents)
 
 # Diffusion Model Definition
 class Diffusion(nn.Module):
-    def __init__(self, n_inputs=3, n_embed=512, n_head=8, n_layer=12):
+    # FIXED: Changed n_head default from 8 to 4 to make dimensions compatible
+    def __init__(self, n_inputs=3, n_embed=512, n_head=4, n_layer=12):
         super().__init__()
-        # This is also a placeholder for the architecture.
-        # A full UNet-style model is expected here. The key is that it can be called
-        # with x, t, and conditional embeddings, and returns the predicted noise.
         self.time_embed = nn.Embedding(1000, n_inputs * 4)
         self.cond_embed = nn.Linear(n_embed, n_inputs * 4)
-
+
         self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu')
+            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True)
             for _ in range(n_layer)
         ])
         self.out = nn.Linear(n_inputs*4, n_inputs)
 
     def forward(self, x, t, c):
-        # A very simplified forward pass
-        # The actual model is likely a UNet with cross-attention.
         bs, ch, h, w = x.shape
         x = x.permute(0, 2, 3, 1).reshape(bs, h * w, ch)
-
+
         t_emb = self.time_embed(t.long())
         c_emb = self.cond_embed(c)
         emb = t_emb + c_emb
-
-        # This is a gross simplification; a real model would use cross-attention here.
+
         x_out = self.out(x + emb.unsqueeze(1))
         x_out = x_out.reshape(bs, h, w, ch).permute(0, 3, 1, 2)
         return x_out
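For context on the "FIXED" comment in the hunk above: `nn.TransformerEncoderLayer` builds an internal `MultiheadAttention`, which requires `d_model` to be divisible by `nhead`, and here `d_model = n_inputs * 4 = 12`, which 8 does not divide. A minimal standalone check (an illustration, not part of this commit):

import torch.nn as nn

d_model = 3 * 4  # n_inputs * 4, as in the Diffusion class above

# nhead=4 divides d_model=12, so this layer is constructed without error.
ok = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)

# nhead=8 does not divide 12, so construction fails
# ("embed_dim must be divisible by num_heads").
try:
    nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
except Exception as e:
    print("rejected:", e)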
@@ -81,42 +69,35 @@ class Diffusion(nn.Module):
 
 # Sampling Function Definitions
 def get_sigmas(n_steps):
-    """Returns the sigma schedule."""
     t = torch.linspace(1, 0, n_steps + 1)
     return ((t[:-1] ** 2) / (t[1:] ** 2) - 1).sqrt()
 
 @torch.no_grad()
 def plms_sample(model, x, steps, **kwargs):
-    """Poor Man's LMS Sampler"""
     ts = x.new_ones([x.shape[0]])
     sigmas = get_sigmas(steps)
     model_fn = lambda x, t: model(x, t * 1000, **kwargs)
-
-    x_outs = []
+
     old_denoised = None
-
+
     for i in trange(len(sigmas) -1, disable=True):
         denoised = model_fn(x, ts * sigmas[i])
-
+
         if old_denoised is None:
             d = (denoised - x) / sigmas[i]
         else:
-            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]
+            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]
 
         x = x + d * (sigmas[i+1] - sigmas[i])
         old_denoised = denoised
-        x_outs.append(x)
-    return x_outs[-1]
 
-
-
+    return x
+
 def ddim_sample(model, x, steps, eta, **kwargs):
-    # This is a placeholder for a full DDIM implementation
     print("Warning: DDIM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)
 
 def ddpm_sample(model, x, steps, **kwargs):
-    # This is a placeholder for a full DDPM implementation
     print("Warning: DDPM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)
 
@@ -125,16 +106,12 @@ def ddpm_sample(model, x, steps, **kwargs):
 # -----------------------------------------------------------------------------
 
 # 🖼️ Download the necessary model files from HuggingFace
-# NOTE: The HuggingFace URLs you provided might be placeholders.
-# Make sure these point to the correct model files.
 try:
-    # FIXED: Using the new hf_hub_download function with keyword arguments
     vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
     diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
 except Exception as e:
     print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
     print("Using placeholder models which will not produce good images.")
-    # Create dummy files if download fails to allow script to run
     Path("vqvae_model.ckpt").touch()
     Path("diffusion_model.ckpt").touch()
     vqvae_model_path = "vqvae_model.ckpt"
@@ -142,26 +119,17 @@ except Exception as e:
 
 
 # 📐 Utility Functions: Math and images, what could go wrong?
-# These functions help parse prompts and resize/crop images to fit nicely
-
 def parse_prompt(prompt, default_weight=3.):
-    """
-    🎯 Parses a prompt into text and weight.
-    """
     vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])
 
 def resize_and_center_crop(image, size):
-    """
-    ✂️ Resize and crop image to center it beautifully.
-    """
     fac = max(size[0] / image.size[0], size[1] / image.size[1])
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])
 
 # 🧠 Model loading: the brain of our operation! 🔥
-
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')
@@ -171,23 +139,17 @@ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load VQ-VAE-2 Autoencoder
-# NOTE: The VQVAE2 class is a placeholder. Loading a real checkpoint will likely fail
-# unless the class definition perfectly matches the architecture of the saved model.
 try:
     vqvae = VQVAE2()
-    # vqvae.load_state_dict(torch.load(vqvae_model_path, map_location=device))
     print("Skipping VQVAE weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load VQVAE model: {e}. Using placeholder.")
     vqvae = VQVAE2()
 vqvae.eval().requires_grad_(False).to(device)
 
-
 # Load Diffusion Model
-# NOTE: The Diffusion class is a placeholder. This will also likely fail.
 try:
     diffusion_model = Diffusion()
-    # diffusion_model.load_state_dict(torch.load(diffusion_model_path, map_location=device))
     print("Skipping Diffusion Model weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load Diffusion model: {e}. Using placeholder.")
@@ -196,27 +158,19 @@ diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
 
 
 # 🎨 The key function: Where the magic happens!
-# This is where we generate images based on text and image prompts
-
 def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
-    """
-    🖼️ Generates a list of PIL images based on given text and image prompts.
-    """
     zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []
 
-    # Parse text prompts and encode with CLIP
     for prompt in prompts:
         inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
         text_embed = clip_model.get_text_features(**inputs).float()
         target_embeds.append(text_embed)
         weights.append(1.0)
 
-    # Correctly process image prompts from Gradio
-    # Assign a default weight for image prompts
     image_prompt_weight = 1.0
     for image_path in images:
-        if image_path:
+        if image_path:
             try:
                 img = Image.open(image_path).convert('RGB')
                 img = resize_and_center_crop(img, (224, 224))
@@ -227,28 +181,23 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
             except Exception as e:
                 print(f"Warning: Could not process image prompt {image_path}. Error: {e}")
 
-
-    # Adjust weights and set seed
     weights = torch.tensor([1 - sum(weights), *weights], device=device)
     torch.manual_seed(seed)
 
-    # 💡 Model function with classifier-free guidance
    def cfg_model_fn(x, t):
         n = x.shape[0]
         n_conds = len(target_embeds)
         x_in = x.repeat([n_conds, 1, 1, 1])
         t_in = t.repeat([n_conds])
         embed_in = torch.cat(target_embeds).repeat_interleave(n, 0)
-
-        # Ensure correct dimensions for the placeholder Diffusion model
+
         if isinstance(diffusion_model, Diffusion):
-            embed_in = embed_in[:, :512]
-
+            embed_in = embed_in[:, :512]
+
         vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v
 
-    # 🎞️ Run the sampler to generate images
     def run(x, steps):
         if method == 'ddpm':
             return ddpm_sample(cfg_model_fn, x, steps)
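The weighting in `cfg_model_fn` above is a classifier-free-guidance-style blend: the unconditional prediction (the `zero_embed` slot) gets weight `1 - sum(weights)` and each prompt keeps its own weight. A small numeric sketch of that blend, using hypothetical tensors, as an illustration only and not part of this commit:

import torch

# Two "predictions": index 0 unconditional, index 1 conditioned on the text prompt.
vs = torch.stack([torch.full((1, 3, 2, 2), 5.0),   # unconditional
                  torch.full((1, 3, 2, 2), 2.0)])  # text-conditioned
# One prompt with weight 1.0 -> weights = [1 - 1.0, 1.0] = [0.0, 1.0],
# so the blend collapses to the text-conditioned prediction.
weights = torch.tensor([0.0, 1.0])
v = vs.mul(weights[:, None, None, None, None]).sum(0)
print(v.mean())  # tensor(2.)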
@@ -258,52 +207,39 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
             return plms_sample(cfg_model_fn, x, steps)
         assert False, f"Unknown method: {method}"
 
-    # 🏃♂️ Generate the output images
     batch_size = n
     x = torch.randn([n, 3, 64, 64], device=device)
-
     pil_ims = []
     for i in trange(0, n, batch_size):
         cur_batch_size = min(n - i, batch_size)
         out_latents = run(x[i:i + cur_batch_size], steps)
-
-        # The VQVAE expects specific dimensions. Adjusting for the placeholder.
-        # This will likely need tuning for the real model.
+
         if isinstance(vqvae, VQVAE2):
-
-
-
-            quant_guess = F.gumbel_softmax(out_latents, hard=True).permute(0, 2, 3, 1) # (B, H, W, C)
-            pil_ims.append(transforms.ToPILImage()(quant_guess[0].permute(2, 0, 1)))
+            outs = vqvae.decode(out_latents)
+            for j, out in enumerate(outs):
+                pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
         else:
             outs = vqvae.decode(out_latents)
             for j, out in enumerate(outs):
                 pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
-
     return pil_ims
 
 # 🖌️ Interface: Gradio's brush to paint the UI
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
-    """
-    💡 Gradio function to wrap image generation.
-    """
     if seed is None:
         seed = random.randint(0, 10000)
     prompts = [prompt]
     im_prompts = []
     if im_prompt is not None:
         im_prompts = [im_prompt]
-
     try:
         pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
         return pil_ims[0]
     except Exception as e:
         print(f"ERROR during generation: {e}")
-        # Return a blank image on failure
         return Image.new('RGB', (256, 256), color = 'red')
 
-
-# 🖼️ Gradio UI: The interface where users can input text or image prompts
+# 🖼️ Gradio UI
 iface = gr.Interface(
     fn=gen_ims,
     inputs=[