# 🚀 Import all necessary libraries
import os
import argparse
from functools import partial
from pathlib import Path
import sys
import random
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from transformers import CLIPProcessor, CLIPModel
from vqvae import VQVAE2  # Autoencoder replacement
from diffusion_models import Diffusion  # Swapped diffusion model for a DALL·E 2 based model
import sampling  # Assumed local module exposing sample() / plms_sample(); used by run() below
from huggingface_hub import hf_hub_download
import gradio as gr  # 🎨 The magic canvas for AI-powered image generation!
# 🖼️ Download the necessary model files from HuggingFace
vqvae_model_path = hf_hub_download("huggingface/vqvae-2", filename="vqvae_model.ckpt")
diffusion_model_path = hf_hub_download("huggingface/dalle-2", filename="diffusion_model.ckpt")
# 📐 Utility Functions: Math and images, what could go wrong?
# These functions help parse prompts and resize/crop images to fit nicely
def parse_prompt(prompt, default_weight=3.):
    """
    🎯 Parses a 'text:weight' prompt into its text and weight parts.
    """
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]  # Pad with the default weight when no ':weight' suffix is given
    return vals[0], float(vals[1])
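
# Illustrative examples of how parse_prompt behaves (not executed):
#   parse_prompt('a red circle')      -> ('a red circle', 3.0)   # default weight applied
#   parse_prompt('a red circle:1.5')  -> ('a red circle', 1.5)   # explicit weight parsed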
def resize_and_center_crop(image, size):
    """
    ✂️ Resize and crop image to center it beautifully.
    """
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
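
# Illustrative example: for a 640x480 input and size=(224, 224), fac = max(224/640, 224/480) ≈ 0.467,
# so the image is resized to roughly 298x224 and then center-cropped to exactly 224x224.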
# 🧠 Model loading: the brain of our operation! 🔥
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models... 🛠️')

# Load CLIP model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load VQ-VAE-2 Autoencoder (load weights on CPU first, then move to the target device)
vqvae = VQVAE2()
vqvae.load_state_dict(torch.load(vqvae_model_path, map_location='cpu'))
vqvae.eval().requires_grad_(False).to(device)

# Load Diffusion Model
diffusion_model = Diffusion()
diffusion_model.load_state_dict(torch.load(diffusion_model_path, map_location='cpu'))
diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
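
# Note on assumptions: VQVAE2 and Diffusion are local classes, so their exact interfaces are not
# documented here. The code below assumes the diffusion model is callable as
# diffusion_model(x, t, clip_embed) and that VQVAE2 exposes a decode() method that returns
# image tensors roughly in the [0, 1] range.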
# 🎨 The key function: Where the magic happens!
# This is where we generate images based on text and image prompts
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=0.):
    """
    🖼️ Generates a list of PIL images based on the given text and image prompts.
    """
    # The zero CLIP embedding acts as the unconditional branch for classifier-free guidance
    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
    target_embeds, weights = [zero_embed], []
    # Parse text prompts and encode with CLIP
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        inputs = clip_processor(text=txt, return_tensors="pt").to(device)
        text_embed = clip_model.get_text_features(**inputs).float()
        target_embeds.append(text_embed)
        weights.append(weight)

    # Parse image prompts and encode with CLIP
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(path).convert('RGB')
        img = resize_and_center_crop(img, (224, 224))
        inputs = clip_processor(images=img, return_tensors="pt").to(device)
        image_embed = clip_model.get_image_features(**inputs).float()
        target_embeds.append(image_embed)
        weights.append(weight)

    # The unconditional (zero) embedding gets whatever weight keeps the total at 1
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)
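
    # Illustrative example of the weighting above (not executed): with a single text prompt at the
    # default weight of 3.0, weights becomes [-2.0, 3.0], i.e. output = -2 * v_uncond + 3 * v_cond,
    # the usual classifier-free guidance extrapolation away from the unconditional prediction.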
    # 💡 Model function with classifier-free guidance
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        # Run every conditioning (including the unconditional zero embedding) in one batched call
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
        # Combine the per-condition predictions with their guidance weights
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v
    # 🎞️ Run the sampler to generate images
    # `sampling` is assumed to provide sample(model_fn, x, schedule, eta, extra_args) and
    # plms_sample(model_fn, x, schedule, extra_args); adjust if your local module's API differs.
    def run(x, schedule):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, schedule, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, schedule, eta, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, schedule, {})
        raise ValueError(f'Unknown sampling method: {method}')
    # 🏃‍♂️ Generate the output images
    batch_size = n
    x = torch.randn([n, 3, 64, 64], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]  # Timestep schedule from 1 down to (but not including) 0
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], t)  # The sampler consumes the timestep schedule `t`
        outs = vqvae.decode(out_latents)
        for j, out in enumerate(outs):
            pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))  # Clamp to [0, 1] before converting to PIL
    return pil_ims
# 🖌️ Interface: Gradio's brush to paint the UI
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """
    💡 Gradio function to wrap image generation.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]
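
# Quick local smoke test (illustrative only, assumes the checkpoints above loaded correctly):
#   img = gen_ims("A beautiful sunset over the ocean", seed=0, n_steps=10)
#   img.save("sample.png")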
# 🖼️ Gradio UI: The interface where users can input text or image prompts
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        gr.Image(label="Image prompt (optional)", type='filepath'),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    examples=[
        ["A beautiful sunset over the ocean", None],
        ["A futuristic cityscape at night", None],
        ["A surreal dream-like landscape", None],
    ],
    title='CLIP + Diffusion Model Image Generator',
    description="Generate stunning images from text and image prompts using CLIP and a diffusion model.",
)
# 🚀 Launch the Gradio interface with request queuing enabled
iface.queue().launch()