Update app.py
app.py CHANGED
@@ -13,18 +13,15 @@ from torch.nn import functional as F
 from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
-from
-import
-from
-import train_latent_diffusion as train
+from transformers import CLIPProcessor, CLIPModel
+from vqvae import VQVAE2 # Autoencoder replacement
+from diffusion_models import Diffusion # Swapped Diffusion model for DALL·E 2 based model
 from huggingface_hub import hf_hub_url, cached_download
 import gradio as gr # 🎨 The magic canvas for AI-powered image generation!

-# 🖼️ Download the necessary model files
-
-
-ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
-ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
+# 🖼️ Download the necessary model files from HuggingFace
+vqvae_model_path = cached_download(hf_hub_url("huggingface/vqvae-2", filename="vqvae_model.ckpt"))
+diffusion_model_path = cached_download(hf_hub_url("huggingface/dalle-2", filename="diffusion_model.ckpt"))

 # 📐 Utility Functions: Math and images, what could go wrong?
 # These functions help parse prompts and resize/crop images to fit nicely
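The two local modules pulled in above (`vqvae.py` and `diffusion_models.py`) are not part of this diff. The stubs below are only a sketch of the interface the rest of app.py appears to assume from them: the class names come from the imports, the method signatures from how the objects are used further down. Nothing here is code from the PR.

```python
# Hypothetical stubs; they only document the interface app.py assumes, not real implementations.
import torch
from torch import nn


class VQVAE2(nn.Module):
    """Autoencoder: decode() is expected to map latents to [N, 3, H, W] image tensors."""

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class Diffusion(nn.Module):
    """Denoiser: called as diffusion_model(x, t, embed) inside the sampling loop."""

    def forward(self, x: torch.Tensor, t: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
```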
@@ -33,11 +30,7 @@ def parse_prompt(prompt, default_weight=3.):
     """
     🎯 Parses a prompt into text and weight.
     """
-
-        vals = prompt.rsplit(':', 2)
-        vals = [vals[0] + ':' + vals[1], *vals[2:]]
-    else:
-        vals = prompt.rsplit(':', 1)
+    vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])

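For reference, the simplified `parse_prompt` now splits only on the last colon, so prompts that themselves contain a colon (a URL, for instance) are no longer special-cased the way the removed two-step `rsplit` was. A minimal check of the new behaviour, using the default weight of 3.0 from the function signature:

```python
def parse_prompt(prompt, default_weight=3.):
    """Split 'text:weight' on the last colon, falling back to default_weight."""
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])


print(parse_prompt('a red circle:2.5'))  # ('a red circle', 2.5)
print(parse_prompt('a red circle'))      # ('a red circle', 3.0)
```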
@@ -49,59 +42,51 @@ def resize_and_center_crop(image, size):
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])

-
 # 🧠 Model loading: the brain of our operation! 🔥
-# Load all the models: autoencoder, diffusion, and CLOOB

 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')

-#
-
-
-ae_model.eval().requires_grad_(False).to(device)
-ae_model.load_state_dict(torch.load(ae_model_path))
-n_ch, side_y, side_x = 4, 32, 32
-
-# 🌀 Diffusion Model Setup: The artist behind the scenes
-model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
-model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
-model = model.to(device).eval().requires_grad_(False)
+# Load CLIP model
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

-#
-
-
-
-cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
-cloob.eval().requires_grad_(False).to(device)
+# Load VQ-VAE-2 Autoencoder
+vqvae = VQVAE2()
+vqvae.load_state_dict(torch.load(vqvae_model_path))
+vqvae.eval().requires_grad_(False).to(device)

+# Load Diffusion Model
+diffusion_model = Diffusion()
+diffusion_model.load_state_dict(torch.load(diffusion_model_path))
+diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)

 # 🎨 The key function: Where the magic happens!
 # This is where we generate images based on text and image prompts

-def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='
+def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
     """
     🖼️ Generates a list of PIL images based on given text and image prompts.
     """
-    zero_embed = torch.zeros([1,
+    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []

-    # Parse text prompts
+    # Parse text prompts and encode with CLIP
    for prompt in prompts:
-
-
-
+        inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
+        text_embed = clip_model.get_text_features(**inputs).float()
+        target_embeds.append(text_embed)
+        weights.append(1.0)

     # Parse image prompts
     for prompt in images:
         path, weight = parse_prompt(prompt)
-        img = Image.open(
-
-
-
-
-        target_embeds.append(embed)
+        img = Image.open(path).convert('RGB')
+        img = resize_and_center_crop(img, (224, 224))
+        inputs = clip_processor(images=img, return_tensors="pt").to(device)
+        image_embed = clip_model.get_image_features(**inputs).float()
+        target_embeds.append(image_embed)
         weights.append(weight)

     # Adjust weights and set seed
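The unconditional `zero_embed` has to match the width of the CLIP embeddings it is concatenated with; for `openai/clip-vit-base-patch32` that projection dimension is 512. A small standalone check using the same `transformers` classes loaded above:

```python
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = clip_processor(text=["a red circle"], return_tensors="pt")
text_embed = clip_model.get_text_features(**inputs)

print(clip_model.config.projection_dim)  # 512
print(text_embed.shape)                  # torch.Size([1, 512])
```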
@@ -115,7 +100,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         x_in = x.repeat([n_conds, 1, 1, 1])
         t_in = t.repeat([n_conds])
         embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
-        vs =
+        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v

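The `vs.mul(...).sum(0)` line above is the usual multi-prompt guidance trick: the model is evaluated once per condition (the zero embedding plus each prompt embedding), and the per-condition predictions are combined as a weighted sum along the condition axis. A shape-only sketch, with illustrative weights that are not taken from the PR:

```python
import torch

n_conds, n = 3, 2                         # zero embed + 2 prompts, batch of 2 images
vs = torch.randn(n_conds, n, 3, 64, 64)   # one model prediction per condition
weights = torch.tensor([-1.0, 1.5, 1.5])  # illustrative guidance weights

v = vs.mul(weights[:, None, None, None, None]).sum(0)
print(v.shape)  # torch.Size([2, 3, 64, 64])
```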
@@ -131,22 +116,19 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method

     # 🏃♂️ Generate the output images
     batch_size = n
-    x = torch.randn([n,
+    x = torch.randn([n, 3, 64, 64], device=device)
     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
     pil_ims = []
     for i in trange(0, n, batch_size):
         cur_batch_size = min(n - i, batch_size)
         out_latents = run(x[i:i + cur_batch_size], steps)
-        outs =
+        outs = vqvae.decode(out_latents)
         for j, out in enumerate(outs):
-            pil_ims.append(
+            pil_ims.append(transforms.ToPILImage()(out))

     return pil_ims

-
 # 🖌️ Interface: Gradio's brush to paint the UI
-# Gradio is used here to create a user-friendly interface for art generation.
-
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
     """
     💡 Gradio function to wrap image generation.
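One practical caveat on the decode step above: `transforms.ToPILImage()` expects a CHW float tensor in [0, 1] (or a uint8 tensor), so depending on what `vqvae.decode` actually returns, its output may need to be clamped or rescaled first. A self-contained sketch of the tensor-to-PIL conversion, with a random tensor standing in for a decoded image:

```python
import torch
from torchvision import transforms

decoded = torch.rand(3, 64, 64)      # stand-in for one decoded image
decoded = decoded.clamp(0, 1).cpu()  # ToPILImage wants [0, 1] floats (or uint8) on CPU
pil_im = transforms.ToPILImage()(decoded)
pil_im.save("sample.png")
```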
@@ -169,56 +151,12 @@ iface = gr.Interface(
     ],
     outputs=gr.Image(type="pil", label="Generated Image"),
     examples=[
-
-
-
-        ["Abstract Art, in the style of M.C. Escher"],
-        ['Surrealism, in the style of Salvador Dali'],
-        ["Romanesque Art, in the style of Leonardo da Vinci"],
-        ["landscape"],
-        ["portrait"],
-        ["sculpture"],
-        ["photo"],
-        ["figurative"],
-        ["illustration"],
-        ["still life"],
-        ["cityscape"],
-        ["marina"],
-        ["animal painting"],
-        ["graffiti"],
-        ["mythological painting"],
-        ["battle painting"],
-        ["self-portrait"],
-        ["Impressionism, oil on canvas"],
-        ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
-        ["Moon Light Sonata by Basuki Abdullah"],
-        ["Two Trees by M.C. Escher"],
-        ["Futurism, in the style of Wassily Kandinsky"],
-        ["Surrealism, in the style of Edgar Degas"],
-        ["Expressionism, in the style of Wassily Kandinsky"],
-        ["Futurism, in the style of Egon Schiele"],
-        ["Cubism, in the style of Gustav Klimt"],
-        ["Op Art, in the style of Marc Chagall"],
-        ["Romanticism, in the style of M.C. Escher"],
-        ["Futurism, in the style of M.C. Escher"],
-        ["Mannerism, in the style of Paul Klee"],
-        ["High Renaissance, in the style of Rembrandt"],
-        ["Magic Realism, in the style of Gustave Dore"],
-        ["Realism, in the style of Jean-Michel Basquiat"],
-        ["Art Nouveau, in the style of Paul Gauguin"],
-        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
-        ["Baroque, in the style of Edward Hopper"],
-        ["Post-Impressionism, in the style of Wassily Kandinsky"],
-        ["Naturalism, in the style of Rene Magritte"],
-        ["Constructivism, in the style of Paul Cezanne"],
-        ["Abstract Expressionism, in the style of Henri Matisse"],
-        ["Pop Art, in the style of Vincent van Gogh"],
-        ["Futurism, in the style of Zdzislaw Beksinski"],
-        ["Aaron Wacker, oil on canvas"]
+        ["A beautiful sunset over the ocean"],
+        ["A futuristic cityscape at night"],
+        ["A surreal dream-like landscape"]
     ],
-    title='
-    description="
-    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
+    title='CLIP + Diffusion Model Image Generator',
+    description="Generate stunning images from text and image prompts using CLIP and a diffusion model.",
 )

 # 🚀 Launch the Gradio interface
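The changed hunks end just before the launch call, so it is not shown here; for context, a Gradio Space typically starts the interface defined above with nothing more than:

```python
iface.launch()  # standard Gradio launch; this line sits below the last hunk and is not touched by this commit
```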