Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler | |
from diffusers.schedulers import TCDScheduler | |
import spaces | |
from PIL import Image | |
import os | |
import re | |
from datetime import datetime | |
import random | |
import glob | |
# When True, generated images are screened by an NSFW safety checker before
# being shown or saved; set to False to skip loading that model entirely.
SAFETY_CHECKER = True

# Dropdown label -> [LoRA weight filename template ("{}" is filled with the
# pipeline mode, e.g. "sdxl" / "sd15"), default number of inference steps,
# guidance scale passed to the pipeline].
checkpoints: dict[str, list] = {
    "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
}

# Key (checkpoint label + mode) of the LoRA most recently loaded by
# generate_image; None until the first generation request.
loaded = None
if torch.cuda.is_available():
    # Both base pipelines are loaded once at import time in float16 (fp16
    # weight variant) and moved to the GPU; generate_image attaches the PCM
    # LoRA weights to them per request.
    # NOTE(review): when CUDA is unavailable pipe_sdxl / pipe_sd15 are never
    # defined, so a later generate_image call fails with NameError.
    pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
    # The former "runwayml/stable-diffusion-v1-5" repository was removed from
    # the Hugging Face Hub; this is the maintained mirror of the same weights.
    pipe_sd15 = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
if SAFETY_CHECKER:
    # These dependencies are imported and loaded only when the filter is on.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Screen *images* with the safety checker.

        Returns the input list unchanged, paired with whatever the local
        ``safety_checker`` model returns. Callers treat that second value as
        an iterable of per-image NSFW flags.
        """
        clip_inputs = feature_extractor(images, return_tensors="pt").to("cuda")
        pixel_values = clip_inputs.pixel_values.to("cuda")
        # NOTE(review): the image list is wrapped in an extra list here;
        # confirm this matches the local safety_checker module's signature.
        nsfw_flags = safety_checker(images=[images], clip_input=pixel_values)
        return images, nsfw_flags
def save_image(image: "Image.Image", prompt: str) -> str:
    """Save *image* as a PNG in the working directory and return the filename.

    The filename is ``<YYYYmmdd_HHMMSS>_<prompt>.png``. In the prompt part,
    every character other than a word character, ``-``, ``_``, ``.`` or a
    space is replaced by ``_``, and the result is cut to 50 characters.

    NOTE(review): the timestamp only has one-second resolution, so two saves
    of the same prompt within one second reuse the same filename.
    """
    safe_prompt = re.sub(r'[^\w\-_\. ]', '_', prompt)[:50]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = "_".join((stamp, safe_prompt)) + ".png"
    image.save(target)
    return target
def get_image_gallery():
    """List every ``*.png`` in the working directory as gallery entries.

    Each entry is a ``(path, caption)`` pair whose caption is the path
    itself. Newest files (by modification time) come first; files with equal
    mtimes keep their ``glob`` order.
    """
    newest_first = sorted(glob.glob("*.png"), key=os.path.getmtime, reverse=True)
    return [(path, path) for path in newest_first]
def generate_image(
    prompt,
    ckpt,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
    mode="sdxl",
):
    """Generate one image with the selected PCM LoRA and refresh the gallery.

    Args:
        prompt: Text prompt forwarded to the diffusion pipeline.
        ckpt: Key of ``checkpoints``. It selects the LoRA weight file, the step
            count and the guidance scale.
        num_inference_steps: Step count from the UI slider. It is only honoured
            for the "LCM-Like LoRA" checkpoint. Every other checkpoint uses the
            step count stored in ``checkpoints``, matching the UI, which locks
            the slider for those checkpoints.
        progress: Not read here. Gradio detects the ``gr.Progress`` default
            and, with ``track_tqdm=True``, mirrors the pipeline's tqdm bar.
        mode: ``"sdxl"`` uses ``pipe_sdxl``; any other value uses
            ``pipe_sd15``. It also fills the weight-file template and names
            the Hub subfolder.

    Returns:
        A ``(image, gallery)`` tuple:
        - ``image`` is the generated image. When the safety checker flags it,
          a black 512x512 placeholder is returned instead.
        - ``gallery`` is the refreshed list from ``get_image_gallery``.

        Flagged images are not saved.
    """
    global loaded

    weight_template, fixed_steps, guidance_scale = checkpoints[ckpt]
    checkpoint = weight_template.format(mode)
    pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15

    if ckpt != "LCM-Like LoRA":
        # A dropdown change triggers update_steps and this function from the
        # same event, so the slider value received here can still be the
        # previous checkpoint's step count. Use this checkpoint's own count.
        num_inference_steps = fixed_steps
    # Defensive cast: the slider value may arrive as a float.
    num_inference_steps = int(num_inference_steps)

    lora_key = ckpt + mode
    if loaded != lora_key:
        # Drop any LoRA this pipeline already carries, so adapters from
        # earlier requests are not stacked on top of the newly selected one.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(
            "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder=mode
        )
        loaded = lora_key

    if ckpt == "LCM-Like LoRA":
        pipe.scheduler = LCMScheduler()
    else:
        # TCD scheduler: scaled_linear betas 0.00085-0.012 over 1000 training
        # timesteps, with "trailing" timestep spacing.
        pipe.scheduler = TCDScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            timestep_spacing="trailing",
        )

    results = pipe(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
    )
    image = results.images[0]

    if SAFETY_CHECKER:
        _, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512)), get_image_gallery()

    save_image(image, prompt)
    return image, get_image_gallery()
def update_steps(ckpt):
    """Sync the step slider with the newly selected checkpoint.

    The slider is reset to the checkpoint's default step count. It is only
    left editable for the "LCM-Like LoRA" checkpoint.
    """
    editable = ckpt == "LCM-Like LoRA"
    default_steps = checkpoints[ckpt][1]
    return gr.update(interactive=editable, value=default_steps)
# Cap the app width so the layout stays readable on wide screens.
css = """
.gradio-container {
max-width: 60rem !important;
}
"""

# One style is drawn at random (at import time) for each example prompt.
art_styles = [
    "Impressionist",
    "Cubist",
    "Surrealist",
    "Abstract Expressionist",
    "Pop Art",
    "Minimalist",
    "Baroque",
    "Art Nouveau",
    "Pointillist",
    "Fauvism",
]

# Each example row is [prompt, checkpoint label, inference steps], matching
# the inputs wired to gr.Examples.
examples = [
    [f"{random.choice(art_styles)} {subject}", "4-Step", 4]
    for subject in (
        "painting of a majestic lighthouse on a rocky coast. Use bold brushstrokes and a vibrant color palette to capture the interplay of light and shadow as the lighthouse beam cuts through a stormy night sky.",
        "still life featuring a pair of vintage eyeglasses. Focus on the intricate details of the frames and lenses, using a warm color scheme to evoke a sense of nostalgia and wisdom.",
        "depiction of a rustic wooden stool in a sunlit artist's studio. Emphasize the texture of the wood and the interplay of light and shadow, using a mix of earthy tones and highlights.",
        "scene viewed through an ornate window frame. Contrast the intricate details of the window with a dreamy, soft-focus landscape beyond, using a palette that transitions from cool interior tones to warm exterior hues.",
        "close-up study of interlaced fingers. Use a monochromatic color scheme to emphasize the form and texture of the hands, with dramatic lighting to create depth and emotion.",
        "composition featuring a set of dice in motion. Capture the energy and randomness of the throw, using a dynamic color palette and blurred lines to convey movement.",
        "interpretation of heaven. Create an ethereal atmosphere with soft, billowing clouds and radiant light, using a palette of celestial blues, golds, and whites.",
        "portrayal of an ancient, mystical gate. Combine architectural details with elements of fantasy, using a rich, jewel-toned palette to create an air of mystery and magic.",
        "portrait of a curious cat. Focus on capturing the feline's expressive eyes and sleek form, using a mix of bold and subtle colors to bring out the cat's personality.",
        "abstract representation of toes in sand. Use textured brushstrokes to convey the feeling of warm sand, with a palette inspired by a sun-drenched beach.",
    )
]
def generate_image_sd15(
    prompt,
    ckpt,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run generate_image on the SD 1.5 pipeline.

    This is a named function rather than ``lambda *args``. Gradio finds the
    ``gr.Progress`` default by inspecting the handler's signature, which a
    ``*args`` lambda hides, so without this SD15 runs would get no progress
    tracking.
    """
    return generate_image(
        prompt, ckpt, num_inference_steps, progress=progress, mode="sd15"
    )


with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Phased Consistency Model
Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.
[[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
"""
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", scale=8)
            ckpt = gr.Dropdown(
                label="Select inference steps",
                choices=list(checkpoints.keys()),
                value="4-Step",
            )
            # Only editable for the LCM-like checkpoint (see update_steps).
            steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=20,
                step=1,
                value=4,
                interactive=False,
            )
            submit_sdxl = gr.Button("Run on SDXL", scale=1)
            submit_sd15 = gr.Button("Run on SD15", scale=1)
    img = gr.Image(label="PCM Image")
    gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")

    gen_inputs = [prompt, ckpt, steps]
    gen_outputs = [img, gallery]

    gr.Examples(
        examples=examples,
        inputs=gen_inputs,
        outputs=gen_outputs,
        fn=generate_image,
        cache_examples=True,
    )
    # Update the slider first and only then generate. If both were triggered
    # by the same change event, the generation would read the slider's
    # previous value.
    ckpt.change(
        fn=update_steps,
        inputs=[ckpt],
        outputs=[steps],
        queue=False,
        show_progress=False,
    ).then(
        fn=generate_image,
        inputs=gen_inputs,
        outputs=gen_outputs,
    )
    gr.on(
        fn=generate_image,
        triggers=[prompt.submit, submit_sdxl.click],
        inputs=gen_inputs,
        outputs=gen_outputs,
    )
    gr.on(
        fn=generate_image_sd15,
        triggers=[submit_sd15.click],
        inputs=gen_inputs,
        outputs=gen_outputs,
    )
    # Populate the gallery with previously saved images when the page loads.
    demo.load(fn=get_image_gallery, outputs=gallery)

demo.queue(api_open=False).launch(show_api=False)