import os
os.system("cp opencv.pc /usr/local/lib/pkgconfig/")
os.system("pip install 'numpy<2'")
os.system("pip uninstall triton -y")
import spaces
import io
import base64
import sys
import numpy as np
import torch
from PIL import Image, ImageOps
import gradio as gr
import skimage
import skimage.measure
import yaml
import json
from enum import Enum
from utils import *  # star import; the helpers functbl, correction_func and DummyInterrogator used below come from here
from collections import Counter
import argparse
from stablepy import Model_Diffusers, scheduler_names, ALL_PROMPT_WEIGHT_OPTIONS, SCHEDULE_TYPE_OPTIONS
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from datetime import datetime

parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", type=int, help="listen port", dest="server_port")
parser.add_argument("--host", type=str, help="host", dest="server_name")
parser.add_argument("--share", action="store_true", help="share this app publicly")
parser.add_argument("--debug", action="store_true", help="enable debug mode")
parser.add_argument("--fp32", action="store_true", help="use full precision")
parser.add_argument("--lowvram", action="store_true", help="use low-VRAM mode")
parser.add_argument("--encrypt", action="store_true", help="use https")
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
parser.add_argument(
    "--auth", nargs=2, metavar=("username", "password"), help="username and password for basic auth"
)
parser.add_argument(
    "--remote_model",
    type=str,
    help="use a model (e.g. a DreamBooth fine-tune) from the Hugging Face Hub",
    default="",
)
parser.add_argument(
    "--local_model", type=str, help="use a model stored on your PC", default=""
)
parser.add_argument(
    "--stablepy_model",
    type=str,
    help="model source: a Hugging Face Diffusers repo or a local .safetensors file path",
    default="SG161222/RealVisXL_V5.0_Lightning",
)
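
# Example invocations (a sketch, assuming this file is saved as app.py; all
# flags shown are defined above):
#   python app.py --host 0.0.0.0 --port 7860 --share
#   python app.py --stablepy_model ./models/my_model.safetensors --lowvram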

try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except Exception:
    pass

Interrogator = DummyInterrogator

START_DEVICE_STABLEPY = "cpu" if os.getenv("SPACES_ZERO_GPU") else None
DEBUG_MODE = False
AUTO_SETUP = True

with open("config.yaml", "r") as yaml_in:
    yaml_object = yaml.safe_load(yaml_in)
    config_json = json.dumps(yaml_object)

def parse_color(color):
    """
    Convert color to a Pillow-friendly (R, G, B, A) tuple in the 0–255 range.
    Supports:
    - tuple/list of floats or ints
    - 'rgba(r, g, b, a)' string
    - 'rgb(r, g, b)' string
    - hex colors: '#RRGGBB' or '#RRGGBBAA'
    """
    if isinstance(color, (tuple, list)):
        parts = [float(c) for c in color]
    elif isinstance(color, str):
        c = color.strip().lower()
        # Hex color
        if c.startswith("#"):
            c = c.lstrip("#")
            if len(c) == 6:  # RRGGBB
                r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
                return (r, g, b, 255)
            elif len(c) == 8:  # RRGGBBAA
                r, g, b, a = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16), int(c[6:8], 16)
                return (r, g, b, a)
            else:
                raise ValueError(f"Invalid hex color: {color}")
        # RGB / RGBA string
        c = c.replace("rgba", "").replace("rgb", "").replace("(", "").replace(")", "")
        parts = [float(x.strip()) for x in c.split(",")]
    else:
        raise ValueError(f"Unsupported color format: {color}")
    # Ensure alpha
    if len(parts) == 3:
        parts.append(1.0)  # default alpha = 1.0
    return (
        int(round(parts[0])),
        int(round(parts[1])),
        int(round(parts[2])),
        int(round(parts[3] * 255 if parts[3] <= 1 else parts[3])),
    )
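
# Illustrative conversions (worked by hand from the logic above, not exhaustive):
#   parse_color("#ff0000")                -> (255, 0, 0, 255)
#   parse_color("rgba(255, 128, 0, 0.5)") -> (255, 128, 0, 128)   # 0.5 * 255, rounded
#   parse_color((10, 20, 30))             -> (10, 20, 30, 255)    # alpha defaults to opaque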

def is_not_dark(color, threshold=30):
    return not all(c <= threshold for c in color)


def get_dominant_color_exclude_dark(image_pil):
    img_small = image_pil.convert("RGB").resize((50, 50))
    pixels = list(img_small.getdata())
    filtered_pixels = [p for p in pixels if is_not_dark(p)]
    if not filtered_pixels:
        filtered_pixels = pixels
    most_common = Counter(filtered_pixels).most_common(1)[0][0]
    return most_common
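
# Example: for a letterboxed photo with black bars, the bar pixels fail
# is_not_dark() and are filtered out, so the returned dominant color reflects
# the visible content; if the whole image is dark, all pixels are kept as a
# fallback.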

def replace_color_in_mask(image_pil, mask_pil, target_color=None):
    img = np.array(image_pil.convert("RGB"))
    mask = np.array(mask_pil.convert("L"))
    mask_white = mask == 255
    mask_nonwhite = ~mask_white
    if target_color in [None, ""]:
        nonwhite_pixels = img[mask_nonwhite]
        nonwhite_img = Image.fromarray(nonwhite_pixels.reshape((-1, 1, 3)))  # , mode="RGB"
        target_color = get_dominant_color_exclude_dark(nonwhite_img)
    else:
        parsed = parse_color(target_color)  # (R, G, B, A)
        target_color = parsed[:3]  # ignore alpha for replacement
    img[mask_white] = target_color
    return Image.fromarray(img)
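
# Example usage (hypothetical values): with target_color=None the white mask
# region is filled with the dominant non-dark color of the rest of the image;
# with e.g. target_color="#87ceeb" the masked pixels become that RGB color
# (the alpha returned by parse_color is ignored here).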

def expand_white_around_black(image: Image.Image, expand_ratio=0.1) -> Image.Image:
    """
    Expand the white areas around the black region by a percentage of the black region size.

    Args:
        image: PIL grayscale image (mode "L").
        expand_ratio: Fraction of the black region size by which to expand the
            white sides (default 0.1 = 10%).

    Returns:
        PIL Image with white expanded around black.
    """
    arr = np.array(image)
    black_mask = arr == 0
    height, width = arr.shape
    coords = np.argwhere(black_mask)
    if coords.size == 0:
        # No black pixels, return original image
        return image.copy()
    y_min, x_min = coords.min(axis=0)
    y_max, x_max = coords.max(axis=0)
    expand_x = int((x_max - x_min + 1) * expand_ratio)
    expand_y = int((y_max - y_min + 1) * expand_ratio)
    # Shrink the black bounding box (which expands the surrounding white),
    # but only on sides that are entirely white outside the box
    if y_min > 0 and np.all(arr[:y_min, :] == 255):
        y_min = min(height - 1, y_min + expand_y)
    if y_max < height - 1 and np.all(arr[y_max + 1:, :] == 255):
        y_max = max(0, y_max - expand_y)
    if x_min > 0 and np.all(arr[:, :x_min] == 255):
        x_min = min(width - 1, x_min + expand_x)
    if x_max < width - 1 and np.all(arr[:, x_max + 1:] == 255):
        x_max = max(0, x_max - expand_x)
    # Create new white canvas
    expanded_arr = np.full_like(arr, 255)
    # Paint black inside adjusted bounding box
    expanded_arr[y_min:y_max + 1, x_min:x_max + 1] = 0
    return Image.fromarray(expanded_arr)
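
# Worked example: on an 8x8 white canvas whose black box spans rows/cols 2-5,
# expand_ratio=0.25 gives expand_x = expand_y = int(4 * 0.25) = 1; every side
# of the box bordered by pure white is pulled in by one pixel, so the black
# region shrinks to rows/cols 3-4 and the white margin grows accordingly.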

def load_html():
    body, canvaspy = "", ""
    with open("index.html", encoding="utf8") as f:
        body = f.read()
    with open("canvas.py", encoding="utf8") as f:
        canvaspy = f.read()
    body = body.replace("- paths:\n", "")
    body = body.replace(" - ./canvas.py\n", "")
    body = body.replace("from canvas import InfCanvas", canvaspy)
    return body


def test(x):
    x = load_html()
    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""

try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except AttributeError:
    # Older Pillow versions expose the constant directly
    SAMPLING_MODE = Image.LANCZOS

try:
    contain_func = ImageOps.contain
except AttributeError:
    def contain_func(image, size, method=SAMPLING_MODE):
        # Backport of PIL's ImageOps.contain:
        # https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        im_ratio = image.width / image.height
        dest_ratio = size[0] / size[1]
        if im_ratio != dest_ratio:
            if im_ratio > dest_ratio:
                new_height = int(image.height / image.width * size[0])
                if new_height != size[1]:
                    size = (size[0], new_height)
            else:
                new_width = int(image.width / image.height * size[1])
                if new_width != size[0]:
                    size = (new_width, size[1])
        return image.resize(size, resample=method)
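
# Worked example for the fallback: containing a 1000x500 image in (512, 512)
# gives im_ratio = 2.0 > dest_ratio = 1.0, so new_height = int(500 / 1000 * 512)
# = 256 and the image is resized to (512, 256), preserving aspect ratio.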

if __name__ == "__main__":
    args = parser.parse_args()
else:
    args = parser.parse_args(["--debug"])
if args.auth is not None:
    args.auth = tuple(args.auth)

model = {}


def get_token():
    token = ""
    if os.path.exists(".token"):
        with open(".token", "r") as f:
            token = f.read()
    token = os.environ.get("hftoken", token)
    return token


def save_token(token):
    with open(".token", "w") as f:
        f.write(token)

def my_resize(width, height):
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    smaller = min(width, height)
    larger = max(width, height)
    if larger >= 608:
        return width, height
    factor = 1
    if smaller < 290:
        factor = 2
    elif smaller < 330:
        factor = 1.75
    elif smaller < 384:
        factor = 1.375
    elif smaller < 400:
        factor = 1.25
    elif smaller < 450:
        factor = 1.125
    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
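
# Worked example: my_resize(300, 450) -> the smaller side 300 selects
# factor = 1.75, so width becomes int(1.75 * 300) // 8 * 8 = 520 and height
# int(1.75 * 450) // 8 * 8 = 784; both are snapped down to multiples of 8.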

def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None
):
    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]
    # cast to the dtype of text_encoder (Tensor.to is not in-place, so reassign)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
    # add the token to the tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )
    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))
    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
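
# Hypothetical usage sketch (this helper is not called elsewhere in this file);
# the model ID and file name below are placeholders:
#   from transformers import CLIPTextModel, CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
#   load_learned_embed_in_clip("my-concept.bin", text_encoder, tokenizer, token="<my-concept>")
# After this, "<my-concept>" can be referenced inside prompts.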

MODEL_NAME = args.stablepy_model
print(f"Loading model {MODEL_NAME}. This may take some time if it is a Diffusers-format model.")

LOAD_PIPE_ARGS = dict(
    vae_model=None,
    retain_task_model_in_cache=True,
    controlnet_model="Automatic",
    type_model_precision=torch.float16,
)

disable_progress_bars()
base_model = Model_Diffusers(
    base_model_id=MODEL_NAME,
    task_name="repaint",
    device=START_DEVICE_STABLEPY,
    **LOAD_PIPE_ARGS,
)
enable_progress_bars()

if START_DEVICE_STABLEPY:
    base_model.device = torch.device("cuda:0")
    base_model.pipe.to(torch.device("cuda:0"), torch.float16)
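# Note: under Spaces ZeroGPU the pipeline is deliberately built on CPU
# (START_DEVICE_STABLEPY above), since the GPU is only attached on demand, and
# then moved to cuda:0 here. Outside ZeroGPU, stablepy selects the device itself.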
# maybe a second base_model for anime images

class StableDiffusion:
    def __init__(
        self,
        token: str = "",
        model_name: str = "stable-diffusion-v1-5/stable-diffusion-v1-5",
        model_path: str = None,
        inpainting_model: bool = False,
        **kwargs,
    ):
        if DEBUG_MODE:
            print("sd task selection")

    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        controlnet_union=True,
        expand_mask_percent=0.1,
        color_selector_=None,
        scheduler_type="Automatic",
        prompt_weight="Classic",
        image_resolution=1024,
        img_height=1024,
        img_width=1024,
        loraA=None,
        loraAscale=1.0,
        **kwargs,
    ):
        global base_model
        width, height = image_pil.size
        if DEBUG_MODE:
            image_pil.save(
                f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
            )
            print(image_pil.size)
        sel_buffer = np.array(image_pil)
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height
        extra_kwargs = {
            "num_steps": step,
            "guidance_scale": guidance_scale,
            "sampler": scheduler,
            "num_images": generate_num,
            "negative_prompt": negative_prompt,
            "seed": (seed_val if use_seed else -1),
            "strength": strength,
            "schedule_type": scheduler_type,
            "syntax_weights": prompt_weight,
            "lora_A": (loraA if loraA != "None" else None),
            "lora_scale_A": loraAscale,
        }
        if resize_check:
            process_width, process_height = my_resize(width, height)
            extra_kwargs["image_resolution"] = 1024
        else:
            extra_kwargs["image_resolution"] = image_resolution

        if nmask.sum() < 1 and enable_img2img:
            # Img2img
            init_image = Image.fromarray(img)
            base_model.load_pipe(
                base_model_id=MODEL_NAME,
                task_name="img2img",
                **LOAD_PIPE_ARGS,
            )
            images = base_model(
                prompt=prompt,
                image=init_image.resize(
                    (process_width, process_height), resample=SAMPLING_MODE
                ),
                strength=strength,
                **extra_kwargs,
            )[0]
        elif mask.sum() > 0:
            if fill_mode == "g_diffuser" or "_color" in fill_mode:
                mask = 255 - mask
                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
                if "_color" not in fill_mode:
                    img, mask = functbl[fill_mode](img, mask)
                # extra_kwargs["out_mask"] = Image.fromarray(mask)
                # inpaint_func = unified
            else:
                img, mask = functbl[fill_mode](img, mask)
                mask = 255 - mask
                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
                # inpaint_func = inpaint
            init_image = Image.fromarray(img)
            mask_image = Image.fromarray(mask)
            # mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=8))
            if DEBUG_MODE:
                init_image.save(
                    f"init_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                )
                print(init_image.size)
                mask_image.save(
                    f"mask_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                )
                print(mask_image.size)
            if fill_mode == "pad_common_color":
                init_image = replace_color_in_mask(init_image, mask_image, None)
            elif fill_mode == "pad_selected_color":
                init_image = replace_color_in_mask(init_image, mask_image, color_selector_)
            # Resize after the optional color fill so the fill is actually
            # visible to the pipeline (resizing first would discard it)
            input_image = init_image.resize(
                (process_width, process_height), resample=SAMPLING_MODE
            )
            if expand_mask_percent:
                if mask_image.mode != "L":
                    if DEBUG_MODE:
                        print("convert to L")
                    mask_image = mask_image.convert("L")
                mask_image = expand_white_around_black(mask_image, expand_ratio=expand_mask_percent)
                if DEBUG_MODE:
                    mask_image.save(
                        f"mask_image_expanded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                    )
                    print(mask_image.size)
            if controlnet_union:
                # Outpaint
                base_model.load_pipe(
                    base_model_id=MODEL_NAME,
                    task_name="repaint",
                    **LOAD_PIPE_ARGS,
                )
                images = base_model(
                    prompt=prompt,
                    image=input_image,
                    img_width=process_width,
                    img_height=process_height,
                    image_mask=mask_image.resize((process_width, process_height)),
                    **extra_kwargs,
                )[0]
            else:
                # Inpaint
                base_model.load_pipe(
                    base_model_id=MODEL_NAME,
                    task_name="inpaint",
                    **LOAD_PIPE_ARGS,
                )
                images = base_model(
                    prompt=prompt,
                    image=input_image,
                    image_mask=mask_image.resize((process_width, process_height)),
                    **extra_kwargs,
                )[0]
        else:
            # txt2img
            base_model.load_pipe(
                base_model_id=MODEL_NAME,
                task_name="txt2img",
                **LOAD_PIPE_ARGS,
            )
            images = base_model(
                prompt=prompt,
                img_height=img_height,
                img_width=img_width,
                **extra_kwargs,
            )[0]
        if DEBUG_MODE:
            print(f"TASK NAME {base_model.task_name}")
        return images

def generate_images(
    cur_model,
    pil,
    prompt_text,
    negative_prompt_text,
    guidance,
    strength,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    use_seed,
    seed_val,
    generate_num,
    scheduler,
    scheduler_eta,
    enable_img2img,
    width,
    height,
    controlnet_union,
    expand_mask,
    color_selector_,
    scheduler_type,
    prompt_weight,
    image_resolution,
    img_height,
    img_width,
    loraA,
    loraAscale,
):
    return cur_model.run(
        image_pil=pil,
        prompt=prompt_text,
        negative_prompt=negative_prompt_text,
        guidance_scale=guidance,
        strength=strength,
        step=step,
        resize_check=resize_check,
        fill_mode=fill_mode,
        enable_safety=enable_safety,
        use_seed=use_seed,
        seed_val=seed_val,
        generate_num=generate_num,
        scheduler=scheduler,
        scheduler_eta=scheduler_eta,
        enable_img2img=enable_img2img,
        width=width,
        height=height,
        controlnet_union=controlnet_union,
        expand_mask_percent=expand_mask,
        color_selector_=color_selector_,
        scheduler_type=scheduler_type,
        prompt_weight=prompt_weight,
        image_resolution=image_resolution,
        img_height=img_height,
        img_width=img_width,
        loraA=loraA,
        loraAscale=loraAscale,
    )

def run_outpaint(
    sel_buffer_str,
    prompt_text,
    negative_prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    use_correction,
    enable_img2img,
    use_seed,
    seed_val,
    generate_num,
    scheduler,
    scheduler_eta,
    controlnet_union,
    expand_mask,
    color_selector_,
    scheduler_type,
    prompt_weight,
    image_resolution,
    img_height,
    img_width,
    loraA,
    loraAscale,
    interrogate_mode,
    state,
):
    if DEBUG_MODE:
        print("start proceed")
    data = base64.b64decode(str(sel_buffer_str))
    pil = Image.open(io.BytesIO(data))
    if interrogate_mode:
        if "interrogator" not in model:
            model["interrogator"] = Interrogator()
        interrogator = model["interrogator"]
        img = np.array(pil)[:, :, 0:3]
        mask = np.array(pil)[:, :, -1]
        x, y = np.nonzero(mask)
        if len(x) > 0:
            x0, x1 = x.min(), x.max() + 1
            y0, y1 = y.min(), y.max() + 1
            img = img[x0:x1, y0:y1, :]
        pil = Image.fromarray(img)
        interrogate_ret = interrogator.interrogate(pil)
        return (
            gr.update(value=",".join([sel_buffer_str])),
            gr.update(label="Prompt", value=interrogate_ret),
            state,
        )
    width, height = pil.size
    sel_buffer = np.array(pil)
    cur_model = StableDiffusion()
    if DEBUG_MODE:
        print("start inference")
    images = generate_images(
        cur_model,
        pil,
        prompt_text,
        negative_prompt_text,
        guidance,
        strength,
        step,
        resize_check,
        fill_mode,
        enable_safety,
        use_seed,
        seed_val,
        generate_num,
        scheduler,
        scheduler_eta,
        enable_img2img,
        width,
        height,
        controlnet_union,
        expand_mask,
        color_selector_,
        scheduler_type,
        prompt_weight,
        image_resolution,
        img_height,
        img_width,
        loraA,
        loraAscale,
    )
    if DEBUG_MODE:
        print("return result")
    base64_str_lst = []
    if enable_img2img:
        use_correction = "border_mode"
    for image in images:
        image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
        resized_img = image.resize((width, height), resample=SAMPLING_MODE)
        out = sel_buffer.copy()
        out[:, :, 0:3] = np.array(resized_img)
        out[:, :, -1] = 255
        out_pil = Image.fromarray(out)
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        out_buffer.seek(0)
        base64_bytes = base64.b64encode(out_buffer.read())
        base64_str = base64_bytes.decode("ascii")
        base64_str_lst.append(base64_str)
    return (
        gr.update(label=str(state + 1), value=",".join(base64_str_lst)),
        gr.update(label="Prompt"),
        state + 1,
    )


generate_images.zerogpu = True
run_outpaint.zerogpu = True

def load_js(name):
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let app = document.querySelector("gradio-app");
    app = app.shadowRoot ?? app;
    let frame = app.querySelector("#sdinfframe").contentWindow.document;
    let button = frame.querySelector("#{name}");
    button.click();
    return x;
}}
"""
    ret = ""
    with open(f"./js/{name}.js", "r") as f:
        ret = f.read()
    return ret
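
# For the canvas buttons ("export", "commit", "undo") the generated JS simply
# forwards the click into the embedded iframe's matching button; every other
# name is read verbatim from ./js/<name>.js.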

proceed_button_js = load_js("proceed")
setup_button_js = load_js("setup")

blocks = gr.Blocks(
    title="StableDiffusion-Infinity",
    css="""
.tabs {
    margin-top: 0rem;
    margin-bottom: 0rem;
}
#markdown {
    min-height: 0rem;
}
""",
)
model_path_input_val = ""

with blocks as demo:
    # title
    title = gr.Markdown(
        """
        This is a modified demo of [stablediffusion-infinity](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) with SDXL support.

        **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
        """,
        elem_id="markdown",
    )
    # frame
    frame = gr.HTML(test(2), visible=True)
    # setup
    if not AUTO_SETUP:
        model_choices_lst = [""]
        if args.local_model:
            model_path_input_val = args.local_model
            # model_choices_lst.insert(0, "local_model")
        elif args.remote_model:
            model_path_input_val = args.remote_model
            # model_choices_lst.insert(0, "remote_model")
        with gr.Row(elem_id="setup_row"):
            with gr.Column(scale=4, min_width=350):
                token = gr.Textbox(
                    label="Huggingface token",
                    value=get_token(),
                    placeholder="Input your token here / ignore this if using a local model",
                )
            with gr.Column(scale=3, min_width=320):
                model_selection = gr.Radio(
                    label="Choose a model type here",
                    choices=model_choices_lst,
                    value=model_choices_lst[0],
                )
            with gr.Column(scale=1, min_width=100):
                canvas_width = gr.Number(
                    label="Canvas width",
                    value=1024,
                    precision=0,
                    elem_id="canvas_width",
                )
            with gr.Column(scale=1, min_width=100):
                canvas_height = gr.Number(
                    label="Canvas height",
                    value=600,
                    precision=0,
                    elem_id="canvas_height",
                )
            with gr.Column(scale=1, min_width=100):
                selection_size = gr.Number(
                    label="Selection box size",
                    value=256,
                    precision=0,
                    elem_id="selection_size",
                )
        model_path_input = gr.Textbox(
            value=model_path_input_val,
            label="Custom model path (you must select the correct model type for your local model)",
            placeholder="Ignore this if you are not using Docker",
            elem_id="model_path_input",
        )
        setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            init_mode = gr.Radio(
                label="Padding fill method for image",
                choices=[
                    "pad_common_color",
                    "pad_selected_color",
                    "g_diffuser",
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "perlin",
                    "gaussian",
                ],
                value="edge_pad",
                type="value",
            )
            postprocess_check = gr.Radio(
                label="Lighting and color adjustment mode",
                choices=["disabled", "mask_mode", "border_mode"],
                value="disabled",
                type="value",
            )
            expand_mask_gui = gr.Slider(
                0.0,
                0.5,
                value=0.1,
                step=0.01,
                label="Mask Expansion (%)",
                info="Controls how far the mask extends inward from the edges of the image. ⚠️ Important: when merging two images into one via outpainting, set this to 0 to avoid unexpected results.",
            )
            color_selector = gr.ColorPicker(
                value="#FFFFFF",
                label="Color for `pad_selected_color`",
                info="Choose the color used to fill the extended padding area.",
            )
        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here!", lines=4
            )
            sd_negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="input your negative prompt here!",
                lines=4,
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Group():
                with gr.Row():
                    sd_strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.01,
                    )
                with gr.Row():
                    sd_scheduler = gr.Dropdown(
                        scheduler_names,
                        value="TCD",
                        label="Sampler",
                    )
                    sd_scheduler_type = gr.Dropdown(
                        SCHEDULE_TYPE_OPTIONS,
                        value=SCHEDULE_TYPE_OPTIONS[0],
                        label="Schedule type",
                    )
                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0, visible=False)
                    sd_controlnet_union = gr.Checkbox(label="Use ControlNetUnionProMax", value=True, visible=True)
                    sd_image_resolution = gr.Slider(512, 4096, value=1024, step=64, label="Image resolution", info="Resolution at which the image is processed.")
                    sd_img_height = gr.Slider(512, 4096, value=1024, step=64, label="Height for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
                    sd_img_width = gr.Slider(512, 4096, value=1024, step=64, label="Width for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
        with gr.Column(scale=1, min_width=80):
            sd_generate_num = gr.Number(label="Sample number", minimum=1, maximum=10, value=1)
            sd_step = gr.Number(label="Step", value=12, minimum=2)
            sd_guidance = gr.Number(label="Guidance scale", value=1.5, step=0.5)
            sd_prompt_weight = gr.Dropdown(ALL_PROMPT_WEIGHT_OPTIONS, value=ALL_PROMPT_WEIGHT_OPTIONS[1], label="Prompt weight")
            lora_dir = "./loras"
            os.makedirs(lora_dir, exist_ok=True)
            lora_files = [
                f for f in os.listdir(lora_dir)
                if os.path.isfile(os.path.join(lora_dir, f))
            ]
            lora_files.insert(0, "None")
            sd_loraA = gr.Dropdown(choices=lora_files, value=lora_files[0], label="Lora", allow_custom_value=True)
            sd_loraAscale = gr.Slider(-2.0, 2.0, value=1.0, step=0.01, label="Lora scale")
    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    xss_js = load_js("xss").replace("\n", " ")
    # the broken src URLs below are intentional: they trigger onerror, which runs the script
    xss_html = gr.HTML(
        value=f"""
        <img src='hts://not.exist' onerror='{xss_js}'>""",
        visible=False,
    )
    xss_keyboard_js = load_js("keyboard").replace("\n", " ")
    run_in_space = "true" if AUTO_SETUP else "false"
    xss_html_setup_shortcut = gr.HTML(
        value=f"""
        <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
        visible=False,
    )
    # sd pipeline parameters
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Safety checker", value=True, visible=False)
    interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
    upload_button = gr.Button(
        "Before uploading an image you need to set up the canvas first", visible=False
    )
    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
    if not AUTO_SETUP:
        def setup_func(token_val, width, height, size, model_choice, model_path):
            try:
                StableDiffusion()
            except Exception as e:
                print(e)
                return {token: gr.update(value=str(e))}
            init_val = "patchmatch"
            return {
                token: gr.update(visible=False),
                canvas_width: gr.update(visible=False),
                canvas_height: gr.update(visible=False),
                selection_size: gr.update(visible=False),
                setup_button: gr.update(visible=False),
                frame: gr.update(visible=True),
                upload_button: gr.update(value="Upload Image"),
                model_selection: gr.update(visible=False),
                model_path_input: gr.update(visible=False),
                init_mode: gr.update(value=init_val),
            }

        setup_button.click(
            fn=setup_func,
            inputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                model_selection,
                model_path_input,
            ],
            outputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                setup_button,
                frame,
                upload_button,
                model_selection,
                model_path_input,
                init_mode,
            ],
            js=setup_button_js,
        )
    proceed_event = proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_negative_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            postprocess_check,
            sd_img2img,
            sd_use_seed,
            sd_seed_val,
            sd_generate_num,
            sd_scheduler,
            sd_scheduler_eta,
            sd_controlnet_union,
            expand_mask_gui,
            color_selector,
            sd_scheduler_type,
            sd_prompt_weight,
            sd_image_resolution,
            sd_img_height,
            sd_img_width,
            sd_loraA,
            sd_loraAscale,
            interrogate_check,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        js=proceed_button_js,
    )
    # the cancel button can also remove the error overlay
    if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
        cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])

launch_extra_kwargs = {
    "show_error": True,
    # "favicon_path": ""
}
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.pop("lowvram", None)
launch_kwargs.pop("stablepy_model", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab
    launch_kwargs["debug"] = True
    launch_kwargs["share"] = True
    launch_kwargs.pop("encrypt", None)
except ImportError:
    launch_kwargs["share"] = False
if not launch_kwargs["share"]:
    demo.launch()
else:
    launch_kwargs["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_kwargs)