Fabrice-TIERCELIN's picture
Optimize the GPU use
c3ab43a verified
raw
history blame
18.2 kB
import gradio as gr
import numpy as np
import time
import math
import random
import torch
import spaces
from diffusers import StableDiffusionXLInpaintPipeline
from PIL import Image, ImageFilter
# Largest signed 64-bit integer; upper bound for randomly drawn seeds and for
# the seed slider.
max_64_bit_int = 2**63 - 1
# HTML header rendered at the top of the interface (a "Running on <device>"
# paragraph is appended to it further down, once the device is known).
DESCRIPTION="""
<h1 style="text-align: center;">Outpainting demo</h1>
<p style="text-align: center;">This uses code by Fabrice TIERCELIN</p>
<br/>
<a href='https://huggingface.co/spaces/clinteroni/outpainting-with-differential-diffusion-demo?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14'></a>
<br/>
"""
# Choose the execution device once, at import time: CUDA with half-precision
# weights (the checkpoint's "fp16" variant) when a GPU is visible, otherwise
# the CPU with full-precision weights and the default variant.
if torch.cuda.is_available():
    device, floatType, variant = "cuda", torch.float16, "fp16"
else:
    device, floatType, variant = "cpu", torch.float32, None
DESCRIPTION += f"<p>Running on {device}</p>"
# Load the SDXL inpainting checkpoint once and keep it on the chosen device.
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=floatType,
    variant=variant,
).to(device)
def update_seed(is_randomize_seed, seed):
    """Return the seed to use for the next run.

    When ``is_randomize_seed`` is falsy the current ``seed`` is returned
    unchanged; otherwise a fresh seed is drawn uniformly from
    ``[0, max_64_bit_int]``.
    """
    if not is_randomize_seed:
        return seed
    return random.randint(0, max_64_bit_int)
def toggle_debug(is_debug_mode):
    """Return three identical Gradio updates that show or hide the debug images.

    The list has one entry per debug output component; each entry sets
    ``visible`` to ``is_debug_mode``.
    """
    visibility = gr.update(visible=is_debug_mode)
    return [visibility for _ in range(3)]
def noise_color(color, noise):
    """Return ``color`` shifted by a random integer offset in ``[-noise, noise]``.

    The offset comes from the module-level ``random`` generator, so results are
    reproducible after ``random.seed``. The result is not clamped.
    """
    offset = random.randint(-noise, noise)
    return offset + color
def check(
    input_image,
    enlarge_top,
    enlarge_right,
    enlarge_bottom,
    enlarge_left,
    prompt,
    negative_prompt,
    smooth_border,
    num_inference_steps,
    guidance_scale,
    image_guidance_scale,
    strength,
    denoising_steps,
    is_randomize_seed,
    seed,
    debug_mode,
    progress=gr.Progress()):
    """Validate the user inputs before any work is done.

    Only the image, the prompt, the four margins and the smooth border are
    inspected. The remaining parameters are accepted so this function can be
    wired to the same Gradio inputs as ``uncrop``.

    Raises:
        gr.Error: with a user-facing message describing the first problem found.
    """
    if input_image is None:
        raise gr.Error("Please provide an image.")
    if prompt is None or prompt.strip() == "":
        raise gr.Error("Please provide a prompt input.")

    margins = (
        ("top", enlarge_top),
        ("right", enlarge_right),
        ("bottom", enlarge_bottom),
        ("left", enlarge_left),
    )
    # ``uncrop`` treats both ``None`` and ``""`` as "no enlargement" (0), so
    # neither may reach the numeric comparison below.
    set_margins = [(side, margin) for side, margin in margins if margin is not None and margin != ""]
    for side, margin in set_margins:
        if margin < 0:
            raise gr.Error(f"Please provide positive {side} margin.")
    if all(margin == 0 for _, margin in set_margins):
        raise gr.Error("At least one border must be enlarged.")

    # The mask keeps a (width - smooth_border) x (height - smooth_border)
    # rectangle of the original; a larger border would make that size negative.
    if smooth_border is not None and smooth_border > min(input_image.size):
        raise gr.Error("The smooth border must not be larger than the width or the height of the image.")
# Upper bound on the number of pixels the diffusion pipeline is asked to generate.
_MAX_PROCESS_PIXELS = 1024 * 1024


def _margin_or_zero(value):
    """Return an enlargement margin as an ``int``, treating ``None`` and ``""`` as 0."""
    if value is None or value == "":
        return 0
    return int(value)


def _build_enlarged_image(input_image, enlarge_top, enlarge_right, enlarge_bottom, enlarge_left):
    """Build the canvas handed to the inpainting pipeline.

    The input sits in the middle of a canvas enlarged by the given margins.
    The margins are first filled with a blurred, stretched copy of the input
    and with stretched mirror images of it (flipped left/right beside it,
    top/bottom above and below it, rotated 180 degrees in the corners), and
    the whole canvas is blurred again. Every pixel is then redrawn with random
    noise. The noise amplitude is the larger of the pixel's horizontal and
    vertical distance from the input area, capped at 255. Finally the
    untouched input is pasted back.

    Consumes the module-level ``random`` generator (through ``noise_color``).
    """
    original_width, original_height = input_image.size
    output_width = enlarge_left + original_width + enlarge_right
    output_height = enlarge_top + original_height + enlarge_bottom

    # Background: the input stretched to the full canvas and blurred.
    enlarged_image = Image.new(mode=input_image.mode, size=(original_width, original_height), color="black")
    enlarged_image.paste(input_image, (0, 0))
    enlarged_image = enlarged_image.resize((output_width, output_height))
    enlarged_image = enlarged_image.filter(ImageFilter.BoxBlur(20))
    enlarged_image.paste(input_image, (enlarge_left, enlarge_top))

    # Mirrored copies so the margins continue the image's colours: left/right ...
    horizontally_mirrored = input_image.transpose(Image.FLIP_LEFT_RIGHT).resize((original_width * 2, original_height))
    enlarged_image.paste(horizontally_mirrored, (enlarge_left - original_width * 2, enlarge_top))
    enlarged_image.paste(horizontally_mirrored, (enlarge_left + original_width, enlarge_top))
    # ... top/bottom ...
    vertically_mirrored = input_image.transpose(Image.FLIP_TOP_BOTTOM).resize((original_width, original_height * 2))
    enlarged_image.paste(vertically_mirrored, (enlarge_left, enlarge_top - original_height * 2))
    enlarged_image.paste(vertically_mirrored, (enlarge_left, enlarge_top + original_height))
    # ... and the four corners.
    rotated = input_image.transpose(Image.ROTATE_180).resize((original_width * 2, original_height * 2))
    for corner_x in (enlarge_left - original_width * 2, enlarge_left + original_width):
        for corner_y in (enlarge_top - original_height * 2, enlarge_top + original_height):
            enlarged_image.paste(rotated, (corner_x, corner_y))
    enlarged_image = enlarged_image.filter(ImageFilter.BoxBlur(20))

    # Noise layer. The traversal order (column by column, R then G then B)
    # fixes the order of random draws, so a given seed gives the same canvas.
    noise_image = Image.new(mode=input_image.mode, size=(output_width, output_height), color="black")
    source_pixels = enlarged_image.load()
    noise_pixels = noise_image.load()
    inner_right = enlarge_left + original_width
    inner_bottom = enlarge_top + original_height
    for i in range(output_width):
        for j in range(output_height):
            pixel = source_pixels[i, j]
            noise = min(max(enlarge_left - i, i - inner_right, enlarge_top - j, j - inner_bottom, 0), 255)
            noise_pixels[i, j] = (
                noise_color(pixel[0], noise),
                noise_color(pixel[1], noise),
                noise_color(pixel[2], noise),
                255,
            )
    enlarged_image.paste(noise_image, (0, 0))
    enlarged_image.paste(input_image, (enlarge_left, enlarge_top))
    return enlarged_image


def _build_mask_image(input_image, output_size, enlarge_left, enlarge_top, smooth_border):
    """Build the inpainting mask for a canvas of ``output_size``.

    The mask is white (the area the pipeline repaints) everywhere except for a
    black rectangle over the input image. That rectangle is shrunk by
    ``smooth_border // 2`` on the top/left and the rest of ``smooth_border`` on
    the bottom/right. The mask is then box-blurred with radius
    ``smooth_border // 2`` so the transition into the original is gradual.
    """
    original_width, original_height = input_image.size
    half_border = smooth_border // 2
    mask_image = Image.new(mode=input_image.mode, size=output_size, color=(255, 255, 255, 0))
    kept_area = Image.new(
        mode=input_image.mode,
        size=(original_width - smooth_border, original_height - smooth_border),
        color=(0, 0, 0, 0),
    )
    mask_image.paste(kept_area, (enlarge_left + half_border, enlarge_top + half_border))
    return mask_image.filter(ImageFilter.BoxBlur(half_border))


def _compute_process_size(output_width, output_height, max_pixels=_MAX_PROCESS_PIXELS):
    """Return ``(width, height, downscaled)``, the size the pipeline should generate.

    The output size is first scaled down (keeping the aspect ratio) to at most
    ``max_pixels``. Both dimensions are then made multiples of 8. Each
    unaligned dimension is rounded up when the result still fits in
    ``max_pixels``; otherwise the dimension with the larger remainder is
    preferred for rounding up. Dimensions already aligned are left untouched.
    """
    downscaled = max_pixels < output_width * output_height
    if downscaled:
        factor = (max_pixels / (output_width * output_height)) ** 0.5
        width = math.floor(output_width * factor)
        height = math.floor(output_height * factor)
    else:
        width, height = output_width, output_height

    width_remainder = width % 8
    height_remainder = height % 8
    width_down = width - width_remainder
    height_down = height - height_remainder
    width_up = width if width_remainder == 0 else width_down + 8
    height_up = height if height_remainder == 0 else height_down + 8

    if width_up * height_up <= max_pixels:
        return width_up, height_up, downscaled
    if height_remainder <= width_remainder and width_up * height_down <= max_pixels:
        return width_up, height_down, downscaled
    if width_remainder <= height_remainder and width_down * height_up <= max_pixels:
        return width_down, height_up, downscaled
    return width_down, height_down, downscaled


def _format_duration(elapsed_seconds):
    """Format a duration as ``"[H h, ][M min, ]S sec."`` using whole (truncated) seconds."""
    minutes, seconds = divmod(int(elapsed_seconds), 60)
    hours, minutes = divmod(minutes, 60)
    text = f"{hours} h, " if hours != 0 else ""
    if hours != 0 or minutes != 0:
        text += f"{minutes} min, "
    return f"{text}{seconds} sec."


def uncrop(
    input_image,
    enlarge_top,
    enlarge_right,
    enlarge_bottom,
    enlarge_left,
    prompt,
    negative_prompt,
    smooth_border,
    num_inference_steps,
    guidance_scale,
    image_guidance_scale,
    strength,
    denoising_steps,
    is_randomize_seed,
    seed,
    debug_mode,
    progress=gr.Progress()):
    """Outpaint ``input_image`` by the requested margins.

    Validates the inputs with ``check`` and fills in defaults for missing
    values. It then seeds ``random`` and ``torch`` with ``seed`` and builds
    the pre-filled canvas and the inpainting mask. The canvas is generated at
    a size of at most 1024*1024 pixels with both dimensions multiples of 8,
    and the result is resized back to the exact requested size.

    Returns:
        A list of five items:
        - the outpainted image;
        - an information string (size and elapsed time);
        - the (possibly RGB-converted) input image, or ``None``;
        - the canvas, or ``None``;
        - the mask, or ``None``.
        The last three are ``None`` unless ``debug_mode`` is truthy.

    Raises:
        gr.Error: propagated from ``check`` on invalid input.
    """
    check(
        input_image,
        enlarge_top,
        enlarge_right,
        enlarge_bottom,
        enlarge_left,
        prompt,
        negative_prompt,
        smooth_border,
        num_inference_steps,
        guidance_scale,
        image_guidance_scale,
        strength,
        denoising_steps,
        is_randomize_seed,
        seed,
        debug_mode
    )
    start = time.time()
    progress(0, desc = "Preparing data...")

    enlarge_top = _margin_or_zero(enlarge_top)
    enlarge_right = _margin_or_zero(enlarge_right)
    enlarge_bottom = _margin_or_zero(enlarge_bottom)
    enlarge_left = _margin_or_zero(enlarge_left)
    if negative_prompt is None:
        negative_prompt = ""
    smooth_border = 0 if smooth_border is None else int(smooth_border)
    if num_inference_steps is None:
        num_inference_steps = 50
    if guidance_scale is None:
        guidance_scale = 7
    if image_guidance_scale is None:
        image_guidance_scale = 1.5
    if strength is None:
        strength = 0.99
    if denoising_steps is None:
        denoising_steps = 1000
    if seed is None:
        seed = random.randint(0, max_64_bit_int)
    random.seed(seed)
    torch.manual_seed(seed)

    # The canvas code reads three colour channels per pixel; greyscale or
    # palette images would otherwise crash.
    if input_image.mode not in ("RGB", "RGBA"):
        input_image = input_image.convert("RGB")

    original_width, original_height = input_image.size
    output_width = enlarge_left + original_width + enlarge_right
    output_height = enlarge_top + original_height + enlarge_bottom

    enlarged_image = _build_enlarged_image(input_image, enlarge_top, enlarge_right, enlarge_bottom, enlarge_left)
    mask_image = _build_mask_image(input_image, (output_width, output_height), enlarge_left, enlarge_top, smooth_border)

    process_width, process_height, downscaled = _compute_process_size(output_width, output_height)
    limitation = " Due to technical limitations, the image has been downscaled and then upscaled." if downscaled else ""

    progress(None, desc = "Processing...")
    output_image = uncrop_on_gpu(
        seed,
        process_width,
        process_height,
        prompt,
        negative_prompt,
        enlarged_image,
        mask_image,
        num_inference_steps,
        guidance_scale,
        image_guidance_scale,
        strength,
        denoising_steps
    )
    # The pipeline generates at the process size, which differs from the
    # requested size after downscaling OR after 8-pixel alignment; always
    # return exactly the size reported to the user.
    if output_image.size != (output_width, output_height):
        output_image = output_image.resize((output_width, output_height))

    if not debug_mode:
        input_image = None
        enlarged_image = None
        mask_image = None

    information = (
        ("Start again to get a different result. " if is_randomize_seed else "")
        + f"The new image is {output_width} pixels large and {output_height} pixels high, "
        + f"so an image of {output_width * output_height:,} pixels. "
        + f"The image has been generated in {_format_duration(time.time() - start)}"
        + limitation
    )
    return [
        output_image,
        information,
        input_image,
        enlarged_image,
        mask_image
    ]
@spaces.GPU(duration=120)
def uncrop_on_gpu(
    seed,
    process_width,
    process_height,
    prompt,
    negative_prompt,
    enlarged_image,
    mask_image,
    num_inference_steps,
    guidance_scale,
    image_guidance_scale,
    strength,
    denoising_steps
):
    """Run the SDXL inpainting pipeline and return the first generated image.

    The canvas ``enlarged_image`` is repainted where ``mask_image`` is white.
    The result is generated at ``process_width`` x ``process_height``.

    ``seed`` is now passed to the pipeline through an explicit
    ``torch.Generator``, so the same seed gives the same sampling noise.
    Previously it was passed as ``seeds=``, which is not a parameter of
    ``StableDiffusionXLInpaintPipeline.__call__``.

    ``image_guidance_scale`` and ``denoising_steps`` are kept in the signature
    for caller compatibility but are not forwarded. The SDXL inpainting
    pipeline has no such parameters; the former call passed them (with
    ``seeds`` and ``show_progress_bar``) into its catch-all ``**kwargs``,
    where they were silently ignored.
    """
    # A CPU generator works whichever device the pipeline runs on: diffusers
    # draws the noise with it and moves the tensor to the pipeline's device.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=enlarged_image,
        mask_image=mask_image,
        width=process_width,
        height=process_height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=generator,
    ).images[0]
# Gradio user interface. Components are created in layout order:
# - a 3x3 grid with the input image in the centre and one margin field on
#   each side;
# - the prompt and the advanced options;
# - the outputs, plus three debug images that start hidden.
with gr.Blocks() as interface:
    gr.HTML(
        DESCRIPTION
    )
    # Grid row 1: empty cell, top margin, empty cell (invisible labels are
    # placeholders that keep the margin field centred).
    with gr.Row():
        with gr.Column():
            dummy_1 = gr.Label(visible = False)
        with gr.Column():
            enlarge_top = gr.Number(minimum = 0, value = 64, precision = 0, label = "Uncrop on top ⬆️", info = "in pixels")
        with gr.Column():
            dummy_2 = gr.Label(visible = False)
    # Grid row 2: left margin, input image, right margin.
    with gr.Row():
        with gr.Column():
            enlarge_left = gr.Number(minimum = 0, value = 64, precision = 0, label = "Uncrop on left ⬅️", info = "in pixels")
        with gr.Column():
            input_image = gr.Image(label = "Your image", sources = ["upload", "webcam", "clipboard"], type = "pil")
        with gr.Column():
            enlarge_right = gr.Number(minimum = 0, value = 64, precision = 0, label = "Uncrop on right ➡️", info = "in pixels")
    # Grid row 3: empty cell, bottom margin, empty cell.
    with gr.Row():
        with gr.Column():
            dummy_3 = gr.Label(visible = False)
        with gr.Column():
            enlarge_bottom = gr.Number(minimum = 0, value = 64, precision = 0, label = "Uncrop on bottom ⬇️", info = "in pixels")
        with gr.Column():
            dummy_4 = gr.Label(visible = False)
    with gr.Row():
        prompt = gr.Textbox(label = "Prompt", info = "Describe the subject, the background and the style of image; 77 token limit", placeholder = "Describe what you want to see in the entire image", lines = 2)
    # Collapsed advanced settings; their defaults match the fallbacks applied
    # in ``uncrop`` when a value is missing.
    # NOTE(review): the SDXL inpainting pipeline does not appear to accept
    # image_guidance_scale / denoising_steps; confirm these two controls have
    # any effect.
    with gr.Row():
        with gr.Accordion("Advanced options", open = False):
            negative_prompt = gr.Textbox(label = "Negative prompt", placeholder = "Describe what you do NOT want to see in the entire image", value = 'Border, frame, painting, scribbling, smear, noise, blur, watermark')
            smooth_border = gr.Slider(minimum = 0, maximum = 1024, value = 0, step = 2, label = "Smooth border", info = "lower=preserve original, higher=seamless")
            num_inference_steps = gr.Slider(minimum = 10, maximum = 100, value = 50, step = 1, label = "Number of inference steps", info = "lower=faster, higher=image quality")
            guidance_scale = gr.Slider(minimum = 1, maximum = 13, value = 7, step = 0.1, label = "Classifier-Free Guidance Scale", info = "lower=image quality, higher=follow the prompt")
            image_guidance_scale = gr.Slider(minimum = 1, value = 1.5, step = 0.1, label = "Image Guidance Scale", info = "lower=image quality, higher=follow the image")
            strength = gr.Slider(value = 0.99, minimum = 0.01, maximum = 1.0, step = 0.01, label = "Strength", info = "lower=follow the original area (discouraged), higher=redraw from scratch")
            denoising_steps = gr.Number(minimum = 0, value = 1000, step = 1, label = "Denoising", info = "lower=irrelevant result, higher=relevant result")
            randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
            seed = gr.Slider(minimum = 0, maximum = max_64_bit_int, step = 1, randomize = True, label = "Seed")
            debug_mode = gr.Checkbox(label = "Debug mode", value = False, info = "Show intermediate results")
    with gr.Row():
        submit = gr.Button("🚀 Outpaint", variant = "primary")
    with gr.Row():
        uncropped_image = gr.Image(label = "Outpainted image")
    with gr.Row():
        information = gr.HTML()
    # Debug outputs, hidden until ``toggle_debug`` makes them visible.
    with gr.Row():
        original_image = gr.Image(label = "Original image", visible = False)
    with gr.Row():
        enlarged_image = gr.Image(label = "Enlarged image", visible = False)
    with gr.Row():
        mask_image = gr.Image(label = "Mask image", visible = False)
    # Button event chain:
    # 1. pick the seed (``update_seed``);
    # 2. show/hide the debug images (``toggle_debug``);
    # 3. validate the inputs without queueing (``check``);
    # 4. only if validation succeeded, run the outpainting (``uncrop``).
    submit.click(fn = update_seed, inputs = [
        randomize_seed,
        seed
    ], outputs = [
        seed
    ], queue = False, show_progress = False).then(toggle_debug, debug_mode, [
        original_image,
        enlarged_image,
        mask_image
    ], queue = False, show_progress = False).then(check, inputs = [
        input_image,
        enlarge_top,
        enlarge_right,
        enlarge_bottom,
        enlarge_left,
        prompt,
        negative_prompt,
        smooth_border,
        num_inference_steps,
        guidance_scale,
        image_guidance_scale,
        strength,
        denoising_steps,
        randomize_seed,
        seed,
        debug_mode
    ], outputs = [], queue = False,
        show_progress = False).success(uncrop, inputs = [
        input_image,
        enlarge_top,
        enlarge_right,
        enlarge_bottom,
        enlarge_left,
        prompt,
        negative_prompt,
        smooth_border,
        num_inference_steps,
        guidance_scale,
        image_guidance_scale,
        strength,
        denoising_steps,
        randomize_seed,
        seed,
        debug_mode
    ], outputs = [
        uncropped_image,
        information,
        original_image,
        enlarged_image,
        mask_image
    ], scroll_to_output = True)
    # Clicking an example runs ``uncrop`` directly with the listed values
    # (``uncrop`` calls ``check`` itself). Results are not cached.
    gr.Examples(
        run_on_click = True,
        fn = uncrop,
        inputs = [
            input_image,
            enlarge_top,
            enlarge_right,
            enlarge_bottom,
            enlarge_left,
            prompt,
            negative_prompt,
            smooth_border,
            num_inference_steps,
            guidance_scale,
            image_guidance_scale,
            strength,
            denoising_steps,
            randomize_seed,
            seed,
            debug_mode
        ],
        outputs = [
            uncropped_image,
            information,
            original_image,
            enlarged_image,
            mask_image
        ],
        examples = [
            [
                "./examples/Coucang.jpg",
                417,
                0,
                417,
                0,
                "A white Coucang, in a tree, ultrarealistic, realistic, photorealistic, 8k, bokeh",
                "Border, frame, painting, drawing, cartoon, anime, 3d, scribbling, smear, noise, blur, watermark",
                0,
                50,
                7,
                1.5,
                0.99,
                1000,
                False,
                123,
                False
            ],
        ],
        cache_examples = False,
    )
    gr.Markdown(
        """
## Credit
The [example image](https://commons.wikimedia.org/wiki/File:Coucang.jpg) is by Aprisonsan
and licensed under CC-BY-SA 4.0 International.
"""
    )
# Enable the request queue and start the web server.
interface.queue().launch()