Create app.py
app.py
ADDED
@@ -0,0 +1,201 @@
import os

os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
os.system("cd cloob-latent-diffusion; pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")

import argparse
from functools import partial
from pathlib import Path
import sys
sys.path.append('./cloob-latent-diffusion')
sys.path.append('./cloob-latent-diffusion/cloob-training')
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
sys.path.append('./cloob-latent-diffusion/taming-transformers')
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, cached_download
import random

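# NOTE: the sys.path entries above make the vendored checkouts (the CLOOB
# training code, latent-diffusion, taming-transformers, v-diffusion-pytorch)
# importable as top-level modules; model_pt, pretrained, sampling, utils and
# train_latent_diffusion all come from those repos.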
# Download the model files
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))

# Define a few utility functions

def parse_prompt(prompt, default_weight=3.):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])

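# For example, parse_prompt splits an optional ':weight' suffix off a prompt,
# leaving the '://' of URL prompts intact:
#   parse_prompt('a red circle')     -> ('a red circle', 3.0)   # default weight
#   parse_prompt('a red circle:1.5') -> ('a red circle', 1.5)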

def resize_and_center_crop(image, size):
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])

# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
ae_model.load_state_dict(torch.load(ae_model_path))
n_ch, side_y, side_x = 4, 32, 32

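# n_ch, side_y, side_x describe the latent tensor the diffusion model denoises:
# 4 channels at 32x32 spatial resolution in the autoencoder's latent space.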
# diffusion model
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)

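# Note: `checkpoint` is reused here for the CLOOB weights; the distilled
# diffusion checkpoint was already consumed by model.load_state_dict above.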
# The key function: returns a list of n PIL images
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
             method='plms', eta=None):
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    # Embed the text prompts with the CLOOB text encoder
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    # Embed the image prompts with the CLOOB image encoder
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # The unconditional (zero) embedding gets whatever weight makes the total 1
    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    torch.manual_seed(seed)

    # Classifier-free-guidance wrapper: evaluate the model once per conditioning
    # embedding, then take the weighted sum of the predictions
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    # Dispatch to the chosen sampler from v-diffusion-pytorch
    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'prk':
            return sampling.prk_sample(cfg_model_fn, x, steps, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        if method == 'pie':
            return sampling.pie_sample(cfg_model_fn, x, steps, {})
        if method == 'plms2':
            return sampling.plms2_sample(cfg_model_fn, x, steps, {})
        raise ValueError(f'unknown sampling method: {method}')

    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        for out in outs:
            pil_ims.append(utils.to_pil_image(out))

    return pil_ims

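# A minimal smoke test (hedged: hypothetical usage, not run by the app):
# ims = generate(n=1, prompts=['Impressionism, oil on canvas:3'], seed=0, steps=15)
# ims[0].save('sample.png')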
import gradio as gr

# Gradio callback: wraps generate() for a single image
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]

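# Note: gr.inputs.*/gr.outputs.* and launch(enable_queue=...) are the Gradio 2.x
# API this Space targets; Gradio 3+ renamed these (e.g. gr.Textbox, gr.Image,
# iface.queue()).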
iface = gr.Interface(fn=gen_ims,
                     inputs=[
                         # gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1, label="Number of images"),
                         # gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
                         gr.inputs.Textbox(label="Text prompt"),
                         gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
                         # gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15, label="Number of steps"),
                     ],
                     outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
                     examples=[
                         ["Impressionism, oil on canvas"],
                         ["Futurism, in the style of Wassily Kandinsky"],
                         ["Art Nouveau, in the style of John Singer Sargent"],
                         ["Surrealism, in the style of Edgar Degas"],
                         ["Expressionism, in the style of Wassily Kandinsky"],
                         ["Futurism, in the style of Egon Schiele"],
                         ["Neoclassicism, in the style of Gustav Klimt"],
                         ["Cubism, in the style of Gustav Klimt"],
                         ["Op Art, in the style of Marc Chagall"],
                         ["Romanticism, in the style of M.C. Escher"],
                         ["Futurism, in the style of M.C. Escher"],
                         ["Abstract Art, in the style of M.C. Escher"],
                         ["Mannerism, in the style of Paul Klee"],
                         ["Romanesque Art, in the style of Leonardo da Vinci"],
                         ["High Renaissance, in the style of Rembrandt"],
                         ["Magic Realism, in the style of Gustave Dore"],
                         ["Realism, in the style of Jean-Michel Basquiat"],
                         ["Art Nouveau, in the style of Paul Gauguin"],
                         ["Avant-garde, in the style of Pierre-Auguste Renoir"],
                         ["Baroque, in the style of Edward Hopper"],
                         ["Post-Impressionism, in the style of Wassily Kandinsky"],
                         ["Naturalism, in the style of Rene Magritte"],
                         ["Constructivism, in the style of Paul Cezanne"],
                         ["Abstract Expressionism, in the style of Henri Matisse"],
                         ["Pop Art, in the style of Vincent van Gogh"],
                         ["Futurism, in the style of Zdzislaw Beksinski"],
                         ["Surrealism, in the style of Salvador Dali"],
                         ["Aaron Wacker, oil on canvas"],
                     ],
                     title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia',
                     description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, an encyclopedia of visual art",
                     article='Model used: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
                     )
iface.launch(enable_queue=True)  # debug=True for Colab debugging