awacke1 committed on
Commit
a3dfcb0
1 Parent(s): 3836db0

Update app.py

Files changed (1)
  1. app.py +39 -101
app.py CHANGED
@@ -13,18 +13,15 @@ from torch.nn import functional as F
 from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
-from cloob_training import model_pt, pretrained
-import ldm.models.autoencoder
-from diffusion import sampling, utils
-import train_latent_diffusion as train
+from transformers import CLIPProcessor, CLIPModel
+from vqvae import VQVAE2  # Autoencoder replacement
+from diffusion_models import Diffusion  # Swapped Diffusion model for DALL·E 2 based model
 from huggingface_hub import hf_hub_url, cached_download
 import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
 
-# 🖼️ Download the necessary model files
-# These files are loaded from HuggingFace's repository
-checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
-ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
-ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
+# 🖼️ Download the necessary model files from HuggingFace
+vqvae_model_path = cached_download(hf_hub_url("huggingface/vqvae-2", filename="vqvae_model.ckpt"))
+diffusion_model_path = cached_download(hf_hub_url("huggingface/dalle-2", filename="diffusion_model.ckpt"))
 
 # 📐 Utility Functions: Math and images, what could go wrong?
 # These functions help parse prompts and resize/crop images to fit nicely
@@ -33,11 +30,7 @@ def parse_prompt(prompt, default_weight=3.):
     """
     🎯 Parses a prompt into text and weight.
     """
-    if prompt.startswith('http://') or prompt.startswith('https://'):
-        vals = prompt.rsplit(':', 2)
-        vals = [vals[0] + ':' + vals[1], *vals[2:]]
-    else:
-        vals = prompt.rsplit(':', 1)
+    vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])
 
@@ -49,59 +42,51 @@ def resize_and_center_crop(image, size):
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])
 
-
 # 🧠 Model loading: the brain of our operation! 🔥
-# Load all the models: autoencoder, diffusion, and CLOOB
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')
 
-# 🔧 Autoencoder Setup: Let’s decode the madness into images
-ae_config = OmegaConf.load(ae_config_path)
-ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
-ae_model.eval().requires_grad_(False).to(device)
-ae_model.load_state_dict(torch.load(ae_model_path))
-n_ch, side_y, side_x = 4, 32, 32
-
-# 🌀 Diffusion Model Setup: The artist behind the scenes
-model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
-model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
-model = model.to(device).eval().requires_grad_(False)
-
-# 👁️ CLOOB Setup: Our vision model to understand art in human style
-cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
-cloob = model_pt.get_pt_model(cloob_config)
-checkpoint = pretrained.download_checkpoint(cloob_config)
-cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
-cloob.eval().requires_grad_(False).to(device)
+# Load CLIP model
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+# Load VQ-VAE-2 Autoencoder
+vqvae = VQVAE2()
+vqvae.load_state_dict(torch.load(vqvae_model_path))
+vqvae.eval().requires_grad_(False).to(device)
+
+# Load Diffusion Model
+diffusion_model = Diffusion()
+diffusion_model.load_state_dict(torch.load(diffusion_model_path))
+diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
 
 # 🎨 The key function: Where the magic happens!
 # This is where we generate images based on text and image prompts
 
-def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='plms', eta=None):
+def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
     """
     🖼️ Generates a list of PIL images based on given text and image prompts.
     """
-    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
+    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []
 
-    # Parse text prompts
+    # Parse text prompts and encode with CLIP
     for prompt in prompts:
-        txt, weight = parse_prompt(prompt)
-        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
-        weights.append(weight)
+        inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
+        text_embed = clip_model.get_text_features(**inputs).float()
+        target_embeds.append(text_embed)
+        weights.append(1.0)
 
     # Parse image prompts
     for prompt in images:
         path, weight = parse_prompt(prompt)
-        img = Image.open(utils.fetch(path)).convert('RGB')
-        clip_size = cloob.config['image_encoder']['image_size']
-        img = resize_and_center_crop(img, (clip_size, clip_size))
-        batch = TF.to_tensor(img)[None].to(device)
-        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
-        target_embeds.append(embed)
+        img = Image.open(path).convert('RGB')
+        img = resize_and_center_crop(img, (224, 224))
+        inputs = clip_processor(images=img, return_tensors="pt").to(device)
+        image_embed = clip_model.get_image_features(**inputs).float()
+        target_embeds.append(image_embed)
        weights.append(weight)
 
     # Adjust weights and set seed
@@ -115,7 +100,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         x_in = x.repeat([n_conds, 1, 1, 1])
         t_in = t.repeat([n_conds])
         embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
-        vs = model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
+        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v
 
@@ -131,22 +116,19 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
 
     # 🏃‍♂️ Generate the output images
     batch_size = n
-    x = torch.randn([n, n_ch, side_y, side_x], device=device)
+    x = torch.randn([n, 3, 64, 64], device=device)
     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
     pil_ims = []
     for i in trange(0, n, batch_size):
         cur_batch_size = min(n - i, batch_size)
         out_latents = run(x[i:i + cur_batch_size], steps)
-        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
+        outs = vqvae.decode(out_latents)
         for j, out in enumerate(outs):
-            pil_ims.append(utils.to_pil_image(out))
+            pil_ims.append(transforms.ToPILImage()(out))
 
     return pil_ims
 
-
 # 🖌️ Interface: Gradio's brush to paint the UI
-# Gradio is used here to create a user-friendly interface for art generation.
-
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
     """
     💡 Gradio function to wrap image generation.
@@ -169,56 +151,12 @@ iface = gr.Interface(
     ],
     outputs=gr.Image(type="pil", label="Generated Image"),
     examples=[
-        ["Virgin and Child, in the style of Jacopo Bellini"],
-        ["Art Nouveau, in the style of John Singer Sargent"],
-        ["Neoclassicism, in the style of Gustav Klimt"],
-        ["Abstract Art, in the style of M.C. Escher"],
-        ['Surrealism, in the style of Salvador Dali'],
-        ["Romanesque Art, in the style of Leonardo da Vinci"],
-        ["landscape"],
-        ["portrait"],
-        ["sculpture"],
-        ["photo"],
-        ["figurative"],
-        ["illustration"],
-        ["still life"],
-        ["cityscape"],
-        ["marina"],
-        ["animal painting"],
-        ["graffiti"],
-        ["mythological painting"],
-        ["battle painting"],
-        ["self-portrait"],
-        ["Impressionism, oil on canvas"],
-        ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
-        ["Moon Light Sonata by Basuki Abdullah"],
-        ["Two Trees by M.C. Escher"],
-        ["Futurism, in the style of Wassily Kandinsky"],
-        ["Surrealism, in the style of Edgar Degas"],
-        ["Expressionism, in the style of Wassily Kandinsky"],
-        ["Futurism, in the style of Egon Schiele"],
-        ["Cubism, in the style of Gustav Klimt"],
-        ["Op Art, in the style of Marc Chagall"],
-        ["Romanticism, in the style of M.C. Escher"],
-        ["Futurism, in the style of M.C. Escher"],
-        ["Mannerism, in the style of Paul Klee"],
-        ["High Renaissance, in the style of Rembrandt"],
-        ["Magic Realism, in the style of Gustave Dore"],
-        ["Realism, in the style of Jean-Michel Basquiat"],
-        ["Art Nouveau, in the style of Paul Gauguin"],
-        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
-        ["Baroque, in the style of Edward Hopper"],
-        ["Post-Impressionism, in the style of Wassily Kandinsky"],
-        ["Naturalism, in the style of Rene Magritte"],
-        ["Constructivism, in the style of Paul Cezanne"],
-        ["Abstract Expressionism, in the style of Henri Matisse"],
-        ["Pop Art, in the style of Vincent van Gogh"],
-        ["Futurism, in the style of Zdzislaw Beksinski"],
-        ["Aaron Wacker, oil on canvas"]
+        ["A beautiful sunset over the ocean"],
+        ["A futuristic cityscape at night"],
+        ["A surreal dream-like landscape"]
     ],
-    title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia',
-    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
-    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
+    title='CLIP + Diffusion Model Image Generator',
+    description="Generate stunning images from text and image prompts using CLIP and a diffusion model.",
 )
 
 # 🚀 Launch the Gradio interface
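
Editor's note on the new conditioning path: the commit swaps CLOOB for the CLIP model and processor from Hugging Face transformers. As written, every text prompt is appended with a fixed weight of 1.0, so the weight returned by parse_prompt now only affects image prompts. The sketch below is a minimal, hedged illustration of how the prompt-weight syntax ("a red circle:3") could be carried through to CLIP text embeddings; the normalization and weighted averaging are assumptions for illustration, not the committed behavior of app.py.

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same checkpoint as the diff; projection_dim is 512 for clip-vit-base-patch32.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def parse_prompt(prompt, default_weight=3.):
    # Simplified parser from the updated app.py: "a red circle:2.5" -> ("a red circle", 2.5)
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])

@torch.no_grad()
def embed_text_prompts(prompts):
    # Illustrative only: weight each CLIP text embedding by its prompt weight and average.
    embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        inputs = clip_processor(text=txt, return_tensors="pt", padding=True).to(device)
        embeds.append(F.normalize(clip_model.get_text_features(**inputs).float(), dim=-1))
        weights.append(weight)
    weights = torch.tensor(weights, device=device)
    weights = weights / weights.sum()  # normalize so the weights sum to 1
    return (torch.cat(embeds) * weights[:, None]).sum(0, keepdim=True)  # shape [1, 512]

cond = embed_text_prompts(["a red circle:3", "an oil painting of the sea:1"])
print(cond.shape)  # torch.Size([1, 512])

Image prompts could enter the same weighted sum through clip_model.get_image_features, mirroring the image-prompt loop in the diff.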
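
Editor's note on the new decode path: the sampler now starts from a [n, 3, 64, 64] tensor and each decoded sample is converted with transforms.ToPILImage(). The vqvae and diffusion_models imports are project-local modules rather than published packages, so the snippet below stands in a generic decoder; the clamp and the [-1, 1] to [0, 1] rescale are assumptions, since ToPILImage expects float tensors in [0, 1] and the commit does not show the decoder's output range.

import torch
from torchvision import transforms

to_pil = transforms.ToPILImage()

@torch.no_grad()
def latents_to_pil(decoder, latents):
    # decoder stands in for the project's VQ-VAE-2-style autoencoder (hypothetical here).
    outs = decoder.decode(latents)      # expected shape: [batch, 3, H, W]
    outs = (outs.clamp(-1, 1) + 1) / 2  # assumed rescale: ToPILImage wants floats in [0, 1]
    return [to_pil(out.cpu()) for out in outs]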