import os
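# Environment shims for the hosted Space: register the bundled OpenCV
# pkg-config file, pin NumPy to 1.x (likely for binary compatibility with
# prebuilt wheels), and remove triton, which is not needed here.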
os.system("cp opencv.pc /usr/local/lib/pkgconfig/")
os.system("pip install 'numpy<2'")
os.system("pip uninstall triton -y")
import spaces
import io
import base64
import numpy as np
import torch
from PIL import Image, ImageOps
import gradio as gr
import skimage
import skimage.measure
import yaml
import json
from utils import *
from collections import Counter
import argparse
from stablepy import Model_Diffusers, scheduler_names, ALL_PROMPT_WEIGHT_OPTIONS, SCHEDULE_TYPE_OPTIONS
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from datetime import datetime
parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", type=int, help="listen port", dest="server_port")
parser.add_argument("--host", type=str, help="host", dest="server_name")
parser.add_argument("--share", action="store_true", help="share this app?")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--fp32", action="store_true", help="using full precision")
parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
parser.add_argument("--encrypt", action="store_true", help="using https?")
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
parser.add_argument(
"--auth", nargs=2, metavar=("username", "password"), help="use username password"
)
parser.add_argument(
"--remote_model",
type=str,
help="use a model (e.g. dreambooth fined) from huggingface hub",
default="",
)
parser.add_argument(
"--local_model", type=str, help="use a model stored on your PC", default=""
)
parser.add_argument(
"--stablepy_model",
type=str,
help="Model source: can be a Hugging Face Diffusers repo or a local .safetensors file path",
default="SG161222/RealVisXL_V5.0_Lightning"
)
try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except Exception:
    # __file__ may be undefined (e.g. in an interactive session); stay in cwd.
    pass
Interrogator = DummyInterrogator
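# On ZeroGPU Spaces (SPACES_ZERO_GPU set) the pipeline must be created on CPU;
# it is moved to CUDA right after loading (see the Model_Diffusers setup below).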
START_DEVICE_STABLEPY = "cpu" if os.getenv("SPACES_ZERO_GPU") else None
DEBUG_MODE = False
AUTO_SETUP = True
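# config.yaml is forwarded to the browser-side canvas as a JSON string (see
# the keyboard-shortcut bootstrap HTML further down).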
with open("config.yaml", "r") as yaml_in:
yaml_object = yaml.safe_load(yaml_in)
config_json = json.dumps(yaml_object)
def parse_color(color):
"""
Convert color to Pillow-friendly (R, G, B, A) tuple in 0–255 range.
Supports:
- tuple/list of floats or ints
- 'rgba(r, g, b, a)' string
- 'rgb(r, g, b)' string
- hex colors: '#RRGGBB' or '#RRGGBBAA'
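
    Examples:
        parse_color("#ff000080")             -> (255, 0, 0, 128)
        parse_color("rgba(255, 0, 0, 0.5)")  -> (255, 0, 0, 128)
        parse_color((10, 20, 30))            -> (10, 20, 30, 255)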
"""
if isinstance(color, (tuple, list)):
parts = [float(c) for c in color]
elif isinstance(color, str):
c = color.strip().lower()
# Hex color
if c.startswith("#"):
c = c.lstrip("#")
if len(c) == 6: # RRGGBB
r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
return (r, g, b, 255)
elif len(c) == 8: # RRGGBBAA
r, g, b, a = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16), int(c[6:8], 16)
return (r, g, b, a)
else:
raise ValueError(f"Invalid hex color: {color}")
# RGB / RGBA string
c = c.replace("rgba", "").replace("rgb", "").replace("(", "").replace(")", "")
parts = [float(x.strip()) for x in c.split(",")]
else:
raise ValueError(f"Unsupported color format: {color}")
# Ensure alpha
if len(parts) == 3:
parts.append(1.0) # default alpha = 1.0
return (
int(round(parts[0])),
int(round(parts[1])),
int(round(parts[2])),
int(round(parts[3] * 255 if parts[3] <= 1 else parts[3]))
)
def is_not_dark(color, threshold=30):
return not all(c <= threshold for c in color)
def get_dominant_color_exclude_dark(image_pil):
img_small = image_pil.convert("RGB").resize((50, 50))
pixels = list(img_small.getdata())
filtered_pixels = [p for p in pixels if is_not_dark(p)]
if not filtered_pixels:
filtered_pixels = pixels
most_common = Counter(filtered_pixels).most_common(1)[0][0]
return most_common
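# In the "pad_*_color" fill modes, white (255) mask pixels mark the padding
# area to outpaint; it is flood-filled with either the image's dominant
# non-dark color or an explicit user-picked color before inpainting.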
def replace_color_in_mask(image_pil, mask_pil, target_color=None):
img = np.array(image_pil.convert("RGB"))
mask = np.array(mask_pil.convert("L"))
mask_white = mask == 255
mask_nonwhite = ~mask_white
if target_color in [None, ""]:
nonwhite_pixels = img[mask_nonwhite]
nonwhite_img = Image.fromarray(nonwhite_pixels.reshape((-1, 1, 3))) # , mode="RGB"
target_color = get_dominant_color_exclude_dark(nonwhite_img)
else:
parsed = parse_color(target_color) # (R, G, B, A)
target_color = parsed[:3] # ignore alpha for replacement
img[mask_white] = target_color
return Image.fromarray(img)
def expand_white_around_black(image: Image.Image, expand_ratio=0.1) -> Image.Image:
"""
Expand the white areas around the black region by a percentage of the black region size.
Args:
image: PIL grayscale image (mode "L").
expand_ratio: Fraction of black region size to expand white sides (default 0.1 = 10%).
Returns:
PIL Image with white expanded around black.
"""
arr = np.array(image)
black_mask = arr == 0
height, width = arr.shape
coords = np.argwhere(black_mask)
if coords.size == 0:
# No black pixels, return original image
return image.copy()
y_min, x_min = coords.min(axis=0)
y_max, x_max = coords.max(axis=0)
expand_x = int((x_max - x_min + 1) * expand_ratio)
expand_y = int((y_max - y_min + 1) * expand_ratio)
# Shrink black bounding box to expand white sides
if y_min > 0 and np.all(arr[:y_min, :] == 255):
y_min = min(height - 1, y_min + expand_y)
if y_max < height - 1 and np.all(arr[y_max + 1:, :] == 255):
y_max = max(0, y_max - expand_y)
if x_min > 0 and np.all(arr[:, :x_min] == 255):
x_min = min(width - 1, x_min + expand_x)
if x_max < width - 1 and np.all(arr[:, x_max + 1:] == 255):
x_max = max(0, x_max - expand_x)
# Create new white canvas
expanded_arr = np.full_like(arr, 255)
# Paint black inside adjusted bounding box
expanded_arr[y_min:y_max+1, x_min:x_max+1] = 0
return Image.fromarray(expanded_arr)
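# The frontend is index.html plus a canvas script (canvas.py, apparently run
# via py-script); the canvas source is inlined so the whole UI can be served
# as a single iframe srcdoc string.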
def load_html():
body, canvaspy = "", ""
with open("index.html", encoding="utf8") as f:
body = f.read()
with open("canvas.py", encoding="utf8") as f:
canvaspy = f.read()
body = body.replace("- paths:\n", "")
body = body.replace(" - ./canvas.py\n", "")
body = body.replace("from canvas import InfCanvas", canvaspy)
return body
def test(_):
    # The argument is ignored; the iframe body always comes from load_html().
    x = load_html()
return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
# Pillow >= 9.1 moved resampling filters into Image.Resampling.
try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except AttributeError:
    SAMPLING_MODE = Image.LANCZOS
# ImageOps.contain is unavailable on very old Pillow versions; define a fallback.
try:
    contain_func = ImageOps.contain
except AttributeError:
def contain_func(image, size, method=SAMPLING_MODE):
# from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
im_ratio = image.width / image.height
dest_ratio = size[0] / size[1]
if im_ratio != dest_ratio:
if im_ratio > dest_ratio:
new_height = int(image.height / image.width * size[0])
if new_height != size[1]:
size = (size[0], new_height)
else:
new_width = int(image.width / image.height * size[1])
if new_width != size[0]:
size = (new_width, size[1])
return image.resize(size, resample=method)
if __name__ == "__main__":
args = parser.parse_args()
else:
args = parser.parse_args(["--debug"])
if args.auth is not None:
args.auth = tuple(args.auth)
model = {}
def get_token():
token = ""
if os.path.exists(".token"):
with open(".token", "r") as f:
token = f.read()
token = os.environ.get("hftoken", token)
return token
def save_token(token):
with open(".token", "w") as f:
f.write(token)
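# Upscale heuristic for small selections: diffusion quality drops below
# ~512 px, so the smaller the input, the larger the scale factor, with the
# result snapped to multiples of 8 as required by the VAE/UNet.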
def my_resize(width, height):
if width >= 512 and height >= 512:
return width, height
if width == height:
return 512, 512
smaller = min(width, height)
larger = max(width, height)
if larger >= 608:
return width, height
factor = 1
if smaller < 290:
factor = 2
elif smaller < 330:
factor = 1.75
elif smaller < 384:
factor = 1.375
elif smaller < 400:
factor = 1.25
elif smaller < 450:
factor = 1.125
return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
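# Textual-inversion helper retained from the original app; it does not appear
# to be referenced in the stablepy flow below.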
def load_learned_embed_in_clip(
learned_embeds_path, text_encoder, tokenizer, token=None
):
# https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
# separate token and the embeds
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
    # cast to the dtype of the text_encoder; note Tensor.to returns a new
    # tensor, so the result must be reassigned
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
# add the token in tokenizer
token = token if token is not None else trained_token
num_added_tokens = tokenizer.add_tokens(token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)
# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
MODEL_NAME = args.stablepy_model
print(f"Loading model {MODEL_NAME}. This may take some time if it is a Diffusers-format model.")
LOAD_PIPE_ARGS = dict(
vae_model=None,
retain_task_model_in_cache=True,
controlnet_model="Automatic",
type_model_precision=torch.float16,
)
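# Build the stablepy pipeline once at startup. retain_task_model_in_cache=True
# presumably lets the later load_pipe() task switches (img2img / repaint /
# inpaint / txt2img) reuse the already-loaded weights. Progress bars are
# silenced during the (potentially large) initial download.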
disable_progress_bars()
base_model = Model_Diffusers(
base_model_id=MODEL_NAME,
task_name="repaint",
device=START_DEVICE_STABLEPY,
**LOAD_PIPE_ARGS,
)
enable_progress_bars()
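# ZeroGPU path: the weights were loaded on CPU above, so point the wrapper and
# the underlying diffusers pipe at the GPU now.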
if START_DEVICE_STABLEPY:
base_model.device = torch.device("cuda:0")
base_model.pipe.to(torch.device("cuda:0"), torch.float16)
# maybe a second base_model for anime images
class StableDiffusion:
def __init__(
self,
token: str = "",
model_name: str = "stable-diffusion-v1-5/stable-diffusion-v1-5",
model_path: str = None,
inpainting_model: bool = False,
**kwargs,
):
if DEBUG_MODE:
print("sd task selection")
def run(
self,
image_pil,
prompt="",
negative_prompt="",
guidance_scale=7.5,
resize_check=True,
enable_safety=True,
fill_mode="patchmatch",
strength=0.75,
step=50,
enable_img2img=False,
use_seed=False,
seed_val=-1,
generate_num=1,
scheduler="",
scheduler_eta=0.0,
controlnet_union=True,
expand_mask_percent=0.1,
color_selector_=None,
scheduler_type="Automatic",
prompt_weight="Classic",
image_resolution=1024,
img_height=1024,
img_width=1024,
loraA=None,
loraAscale=1.,
**kwargs,
):
global base_model
width, height = image_pil.size
if DEBUG_MODE:
image_pil.save(
f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
)
print(image_pil.size)
sel_buffer = np.array(image_pil)
img = sel_buffer[:, :, 0:3]
mask = sel_buffer[:, :, -1]
nmask = 255 - mask
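        # Alpha-channel semantics: 255 = existing canvas pixels, 0 = empty area
        # to synthesize. Fully opaque -> img2img, mixed -> out/inpaint,
        # fully transparent -> txt2img (see the branches below).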
process_width = width
process_height = height
extra_kwargs = {
"num_steps": step,
"guidance_scale": guidance_scale,
"sampler": scheduler,
"num_images": generate_num,
"negative_prompt": negative_prompt,
"seed": (seed_val if use_seed else -1),
"strength": strength,
"schedule_type": scheduler_type,
"syntax_weights": prompt_weight,
"lora_A": (loraA if loraA != "None" else None),
"lora_scale_A": loraAscale,
}
if resize_check:
process_width, process_height = my_resize(width, height)
extra_kwargs["image_resolution"] = 1024
else:
extra_kwargs["image_resolution"] = image_resolution
if nmask.sum() < 1 and enable_img2img:
# Img2img
init_image = Image.fromarray(img)
base_model.load_pipe(
base_model_id=MODEL_NAME,
task_name="img2img",
**LOAD_PIPE_ARGS,
)
images = base_model(
prompt=prompt,
image=init_image.resize(
(process_width, process_height), resample=SAMPLING_MODE
),
strength=strength,
**extra_kwargs,
)[0]
elif mask.sum() > 0:
if fill_mode == "g_diffuser" or "_color" in fill_mode:
mask = 255 - mask
mask = mask[:, :, np.newaxis].repeat(3, axis=2)
if "_color" not in fill_mode:
img, mask = functbl[fill_mode](img, mask)
# extra_kwargs["out_mask"] = Image.fromarray(mask)
# inpaint_func = unified
else:
img, mask = functbl[fill_mode](img, mask)
mask = 255 - mask
mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
mask = mask.repeat(8, axis=0).repeat(8, axis=1)
# inpaint_func = inpaint
init_image = Image.fromarray(img)
mask_image = Image.fromarray(mask)
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
input_image = init_image.resize(
(process_width, process_height), resample=SAMPLING_MODE
)
if DEBUG_MODE:
init_image.save(
f"init_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
)
print(init_image.size)
mask_image.save(
f"mask_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
)
print(mask_image.size)
if fill_mode == "pad_common_color":
init_image = replace_color_in_mask(init_image, mask_image, None)
elif fill_mode == "pad_selected_color":
init_image = replace_color_in_mask(init_image, mask_image, color_selector_)
if expand_mask_percent:
if mask_image.mode != "L":
if DEBUG_MODE:
print("convert to L")
mask_image = mask_image.convert("L")
mask_image = expand_white_around_black(mask_image, expand_ratio=expand_mask_percent)
                if DEBUG_MODE:
                    mask_image.save(
                        f"mask_image_expanded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                    )
                    print(mask_image.size)
if controlnet_union:
# Outpaint
base_model.load_pipe(
base_model_id=MODEL_NAME,
task_name="repaint",
**LOAD_PIPE_ARGS,
)
images = base_model(
prompt=prompt,
image=input_image,
img_width=process_width,
img_height=process_height,
image_mask=mask_image.resize((process_width, process_height)),
**extra_kwargs,
)[0]
else:
# Inpaint
base_model.load_pipe(
base_model_id=MODEL_NAME,
task_name="inpaint",
**LOAD_PIPE_ARGS,
)
images = base_model(
prompt=prompt,
image=input_image,
image_mask=mask_image.resize((process_width, process_height)),
**extra_kwargs,
)[0]
else:
# txt2img
base_model.load_pipe(
base_model_id=MODEL_NAME,
task_name="txt2img",
**LOAD_PIPE_ARGS,
)
images = base_model(
prompt=prompt,
img_height=img_height,
img_width=img_width,
**extra_kwargs,
)[0]
if DEBUG_MODE:
print(f"TASK NAME {base_model.task_name}")
return images
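# ZeroGPU: each call runs inside a short-lived GPU allocation (15 s requested).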
@spaces.GPU(duration=15)
def generate_images(
cur_model,
pil,
prompt_text,
negative_prompt_text,
guidance,
strength,
step,
resize_check,
fill_mode,
enable_safety,
use_seed,
seed_val,
generate_num,
scheduler,
scheduler_eta,
enable_img2img,
width,
height,
controlnet_union,
expand_mask,
color_selector_,
scheduler_type,
prompt_weight,
image_resolution,
img_height,
img_width,
loraA,
loraAscale,
):
return cur_model.run(
image_pil=pil,
prompt=prompt_text,
negative_prompt=negative_prompt_text,
guidance_scale=guidance,
strength=strength,
step=step,
resize_check=resize_check,
fill_mode=fill_mode,
enable_safety=enable_safety,
use_seed=use_seed,
seed_val=seed_val,
generate_num=generate_num,
scheduler=scheduler,
scheduler_eta=scheduler_eta,
enable_img2img=enable_img2img,
width=width,
height=height,
controlnet_union=controlnet_union,
expand_mask_percent=expand_mask,
color_selector_=color_selector_,
scheduler_type=scheduler_type,
prompt_weight=prompt_weight,
image_resolution=image_resolution,
img_height=img_height,
img_width=img_width,
loraA=loraA,
loraAscale=loraAscale,
)
def run_outpaint(
sel_buffer_str,
prompt_text,
negative_prompt_text,
strength,
guidance,
step,
resize_check,
fill_mode,
enable_safety,
use_correction,
enable_img2img,
use_seed,
seed_val,
generate_num,
scheduler,
scheduler_eta,
controlnet_union,
expand_mask,
color_selector_,
scheduler_type,
prompt_weight,
image_resolution,
img_height,
img_width,
loraA,
loraAscale,
interrogate_mode,
state,
):
if DEBUG_MODE:
print("start proceed")
data = base64.b64decode(str(sel_buffer_str))
pil = Image.open(io.BytesIO(data))
if interrogate_mode:
if "interrogator" not in model:
model["interrogator"] = Interrogator()
interrogator = model["interrogator"]
img = np.array(pil)[:, :, 0:3]
mask = np.array(pil)[:, :, -1]
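        # Crop to the bounding box of the non-transparent region before
        # interrogating; note np.nonzero returns (row, col) indices.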
x, y = np.nonzero(mask)
if len(x) > 0:
x0, x1 = x.min(), x.max() + 1
y0, y1 = y.min(), y.max() + 1
img = img[x0:x1, y0:y1, :]
pil = Image.fromarray(img)
interrogate_ret = interrogator.interrogate(pil)
return (
gr.update(value=",".join([sel_buffer_str]),),
gr.update(label="Prompt", value=interrogate_ret),
state,
)
width, height = pil.size
sel_buffer = np.array(pil)
cur_model = StableDiffusion()
if DEBUG_MODE:
print("start inference")
images = generate_images(
cur_model,
pil,
prompt_text,
negative_prompt_text,
guidance,
strength,
step,
resize_check,
fill_mode,
enable_safety,
use_seed,
seed_val,
generate_num,
scheduler,
scheduler_eta,
enable_img2img,
width,
height,
controlnet_union,
expand_mask,
color_selector_,
scheduler_type,
prompt_weight,
image_resolution,
img_height,
img_width,
loraA,
loraAscale,
)
if DEBUG_MODE:
print("return result")
base64_str_lst = []
if enable_img2img:
use_correction = "border_mode"
for image in images:
image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
out = sel_buffer.copy()
out[:, :, 0:3] = np.array(resized_img)
out[:, :, -1] = 255
out_pil = Image.fromarray(out)
out_buffer = io.BytesIO()
out_pil.save(out_buffer, format="PNG")
out_buffer.seek(0)
base64_bytes = base64.b64encode(out_buffer.read())
base64_str = base64_bytes.decode("ascii")
base64_str_lst.append(base64_str)
return (
gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
gr.update(label="Prompt"),
state + 1,
)
generate_images.zerogpu = True
run_outpaint.zerogpu = True
def load_js(name):
if name in ["export", "commit", "undo"]:
return f"""
function (x)
{{
let app=document.querySelector("gradio-app");
app=app.shadowRoot??app;
let frame=app.querySelector("#sdinfframe").contentWindow.document;
let button=frame.querySelector("#{name}");
button.click();
return x;
}}
"""
ret = ""
with open(f"./js/{name}.js", "r") as f:
ret = f.read()
return ret
proceed_button_js = load_js("proceed")
setup_button_js = load_js("setup")
blocks = gr.Blocks(
title="StableDiffusion-Infinity",
css="""
.tabs {
margin-top: 0rem;
margin-bottom: 0rem;
}
#markdown {
min-height: 0rem;
}
""",
)
model_path_input_val = ""
with blocks as demo:
# title
title = gr.Markdown(
"""
This is a modified demo of [stablediffusion-infinity](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) with SDXL support.
**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
""",
elem_id="markdown",
)
# frame
frame = gr.HTML(test(2), visible=True)
# setup
if not AUTO_SETUP:
model_choices_lst = [""]
if args.local_model:
model_path_input_val = args.local_model
# model_choices_lst.insert(0, "local_model")
elif args.remote_model:
model_path_input_val = args.remote_model
# model_choices_lst.insert(0, "remote_model")
with gr.Row(elem_id="setup_row"):
with gr.Column(scale=4, min_width=350):
token = gr.Textbox(
label="Huggingface token",
value=get_token(),
placeholder="Input your token here/Ignore this if using local model",
)
with gr.Column(scale=3, min_width=320):
model_selection = gr.Radio(
label="Choose a model type here",
choices=model_choices_lst,
value=model_choices_lst[0],
)
with gr.Column(scale=1, min_width=100):
canvas_width = gr.Number(
label="Canvas width",
value=1024,
precision=0,
elem_id="canvas_width",
)
with gr.Column(scale=1, min_width=100):
canvas_height = gr.Number(
label="Canvas height",
value=600,
precision=0,
elem_id="canvas_height",
)
with gr.Column(scale=1, min_width=100):
selection_size = gr.Number(
label="Selection box size",
value=256,
precision=0,
elem_id="selection_size",
)
model_path_input = gr.Textbox(
value=model_path_input_val,
label="Custom Model Path (You have to select a correct model type for your local model)",
placeholder="Ignore this if you are not using Docker",
elem_id="model_path_input",
)
setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
with gr.Row():
with gr.Column(scale=3, min_width=270):
init_mode = gr.Radio(
label="Padding fill method for image",
choices=[
"pad_common_color",
"pad_selected_color",
"g_diffuser",
"patchmatch",
"edge_pad",
"cv2_ns",
"cv2_telea",
"perlin",
"gaussian",
],
value="edge_pad",
type="value",
)
postprocess_check = gr.Radio(
label="Lighting and color adjustment mode",
choices=["disabled", "mask_mode", "border_mode",],
value="disabled",
type="value",
)
            expand_mask_gui = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="Mask Expansion (%)", info="How far the mask expands from the edges of the image. ⚠️ Important: when merging two images into one via outpainting, set this to 0 to avoid unexpected results.")
            color_selector = gr.ColorPicker(value="#FFFFFF", label="Color for `pad_selected_color`", info="Choose the color used to fill the extended padding area.")
with gr.Column(scale=3, min_width=270):
sd_prompt = gr.Textbox(
label="Prompt", placeholder="input your prompt here!", lines=4
)
sd_negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="input your negative prompt here!",
lines=4,
)
with gr.Column(scale=2, min_width=150):
with gr.Group():
with gr.Row():
sd_strength = gr.Slider(
label="Strength",
minimum=0.0,
maximum=1.0,
value=1.0,
step=0.01,
)
with gr.Row():
sd_scheduler = gr.Dropdown(
scheduler_names,
value="TCD",
label="Sampler",
)
sd_scheduler_type = gr.Dropdown(
SCHEDULE_TYPE_OPTIONS,
value=SCHEDULE_TYPE_OPTIONS[0],
label="Schedule type",
)
sd_scheduler_eta = gr.Number(label="Eta", value=0.0, visible=False)
sd_controlnet_union = gr.Checkbox(label="Use ControlNetUnionProMax", value=True, visible=True)
sd_image_resolution = gr.Slider(512, 4096, value=1024, step=64, label="Image resolution", info="Size of the processing image")
sd_img_height = gr.Slider(512, 4096, value=1024, step=64, label="Height for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
sd_img_width = gr.Slider(512, 4096, value=1024, step=64, label="Width for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
with gr.Column(scale=1, min_width=80):
sd_generate_num = gr.Number(label="Sample number", minimum=1, maximum=10, value=1)
sd_step = gr.Number(label="Step", value=12, minimum=2)
sd_guidance = gr.Number(label="Guidance scale", value=1.5, step=0.5)
sd_prompt_weight = gr.Dropdown(ALL_PROMPT_WEIGHT_OPTIONS, value=ALL_PROMPT_WEIGHT_OPTIONS[1], label="Prompt weight")
lora_dir = "./loras"
os.makedirs(lora_dir, exist_ok=True)
lora_files = [
f for f in os.listdir(lora_dir)
if os.path.isfile(os.path.join(lora_dir, f))
]
lora_files.insert(0, "None")
sd_loraA = gr.Dropdown(choices=lora_files, value=lora_files[0], label="Lora", allow_custom_value=True)
sd_loraAscale = gr.Slider(-2., 2., value=1., step=0.01, label="Lora scale")
proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
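    # The deliberately broken img src below is an onerror trick: it appears to
    # be the hook for running custom JS inside the Gradio page (canvas glue
    # and keyboard shortcuts).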
xss_js = load_js("xss").replace("\n", " ")
xss_html = gr.HTML(
value=f"""
<img src='hts://not.exist' onerror='{xss_js}'>""",
visible=False,
)
xss_keyboard_js = load_js("keyboard").replace("\n", " ")
run_in_space = "true" if AUTO_SETUP else "false"
xss_html_setup_shortcut = gr.HTML(
value=f"""
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
visible=False,
)
# sd pipeline parameters
sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
safety_check = gr.Checkbox(label="Safety checker", value=True, visible=False)
interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
upload_button = gr.Button(
"Before uploading the image you need to setup the canvas first", visible=False
)
sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
model_output_state = gr.State(value=0)
upload_output_state = gr.State(value=0)
cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
if not AUTO_SETUP:
def setup_func(token_val, width, height, size, model_choice, model_path):
try:
StableDiffusion()
except Exception as e:
print(e)
return {token: gr.update(value=str(e))}
init_val = "patchmatch"
return {
token: gr.update(visible=False),
canvas_width: gr.update(visible=False),
canvas_height: gr.update(visible=False),
selection_size: gr.update(visible=False),
setup_button: gr.update(visible=False),
frame: gr.update(visible=True),
upload_button: gr.update(value="Upload Image"),
model_selection: gr.update(visible=False),
model_path_input: gr.update(visible=False),
init_mode: gr.update(value=init_val),
}
setup_button.click(
fn=setup_func,
inputs=[
token,
canvas_width,
canvas_height,
selection_size,
model_selection,
model_path_input,
],
outputs=[
token,
canvas_width,
canvas_height,
selection_size,
setup_button,
frame,
upload_button,
model_selection,
model_path_input,
init_mode,
],
js=setup_button_js,
)
proceed_event = proceed_button.click(
fn=run_outpaint,
inputs=[
model_input,
sd_prompt,
sd_negative_prompt,
sd_strength,
sd_guidance,
sd_step,
sd_resize,
init_mode,
safety_check,
postprocess_check,
sd_img2img,
sd_use_seed,
sd_seed_val,
sd_generate_num,
sd_scheduler,
sd_scheduler_eta,
sd_controlnet_union,
expand_mask_gui,
color_selector,
sd_scheduler_type,
sd_prompt_weight,
sd_image_resolution,
sd_img_height,
sd_img_width,
sd_loraA,
sd_loraAscale,
interrogate_check,
model_output_state,
],
outputs=[model_output, sd_prompt, model_output_state],
js=proceed_button_js,
)
# cancel button can also remove error overlay
    # `cancels=` requires Gradio >= 3.6. Note this naive version check would
    # fail on pre-release version strings (e.g. "4.0.0b1").
    if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
        cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
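# The argparse namespace doubles as launch kwargs; drop the model-selection
# flags that gr.Blocks.launch() does not accept.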
launch_extra_kwargs = {
"show_error": True,
# "favicon_path": ""
}
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.pop("lowvram", None)
launch_kwargs.pop("stablepy_model", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab  # noqa: F401 -- only used to detect the Colab runtime
    launch_kwargs["debug"] = True
    launch_kwargs["share"] = True
    launch_kwargs.pop("encrypt", None)
except ImportError:
    # Not on Colab: honor whatever --share the user passed (defaults to False).
    launch_kwargs.setdefault("share", False)
if not launch_kwargs["share"]:
demo.launch()
else:
launch_kwargs["server_name"] = "0.0.0.0"
demo.queue().launch(**launch_kwargs)