Spaces:
Build error
Build error
import gradio as gr | |
import os | |
import argparse | |
from kohya_gui.class_gui_config import KohyaSSGUIConfig | |
from kohya_gui.dreambooth_gui import dreambooth_tab | |
from kohya_gui.finetune_gui import finetune_tab | |
from kohya_gui.textual_inversion_gui import ti_tab | |
from kohya_gui.utilities import utilities_tab | |
from kohya_gui.lora_gui import lora_tab | |
from kohya_gui.class_lora_tab import LoRATools | |
from kohya_gui.custom_logging import setup_logging | |
from kohya_gui.localization_ext import add_javascript | |
def _read_text_file(path, default=""):
    """Return the UTF-8 text stored at ``path``, or ``default`` if it is missing."""
    if not os.path.exists(path):
        return default
    with open(path, "r", encoding="utf8") as file:
        return file.read()


def _build_launch_kwargs(kwargs):
    """Translate the ``UI`` keyword arguments into ``interface.launch`` arguments."""
    launch_kwargs = {"server_name": kwargs.get("listen")}

    # Authentication is only enabled when both halves of the credential are set.
    username = kwargs.get("username")
    password = kwargs.get("password")
    if username and password:
        launch_kwargs["auth"] = (username, password)

    # A missing, None or non-positive port means "do not pass server_port".
    server_port = kwargs.get("server_port") or 0
    if server_port > 0:
        launch_kwargs["server_port"] = server_port

    inbrowser = kwargs.get("inbrowser", False)
    if inbrowser:
        launch_kwargs["inbrowser"] = inbrowser

    # do_not_share takes precedence over share; when neither is set the
    # "share" key is left out entirely.
    share = kwargs.get("share", False)
    if kwargs.get("do_not_share", False):
        launch_kwargs["share"] = False
    elif share:
        launch_kwargs["share"] = share

    root_path = kwargs.get("root_path", None)
    if root_path:
        launch_kwargs["root_path"] = root_path

    launch_kwargs["debug"] = True
    return launch_kwargs


def UI(**kwargs):
    """Build the Kohya_ss Gradio interface and launch it.

    Relies on a module-level ``log`` logger being bound before the call; in
    this file that happens in the ``__main__`` block.

    Keyword arguments read (all optional):
        language: passed to ``add_javascript``.
        headless: forwarded to every tab builder (default ``False``).
        config: path handed to ``KohyaSSGUIConfig``.
        do_not_use_shell: force ``use_shell_flag`` to ``False``, overriding
            the ``settings.use_shell`` config value.
        listen: value used as ``server_name`` for the launch.
        username, password: enable ``auth`` when both are non-empty.
        server_port: used when greater than 0.
        inbrowser, share, do_not_share, root_path: forwarded to the launch
            when set; ``do_not_share`` wins over ``share``.

    The release string and README text are read from ``./.release`` and
    ``./README.md``; when either file is missing an empty string is used
    instead of failing on an unassigned variable.
    """
    add_javascript(kwargs.get("language"))

    headless = kwargs.get("headless", False)
    log.info(f"headless: {headless}")

    css = ""
    if os.path.exists("./assets/style.css"):
        with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
            log.debug("Load CSS...")
            css += file.read() + "\n"

    release = _read_text_file("./.release")
    README = _read_text_file("./README.md")

    interface = gr.Blocks(
        css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
    )

    config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
    if config.is_config_loaded():
        log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")

    # Shell usage defaults to True, can be changed by the config file, and is
    # finally forced off by the do_not_use_shell option.
    use_shell_flag = config.get("settings.use_shell", True)
    if kwargs.get("do_not_use_shell", False):
        use_shell_flag = False
    if use_shell_flag:
        log.info("Using shell=True when running external commands...")

    with interface:
        with gr.Tab("Dreambooth"):
            # The folder inputs created here are reused by the Utilities tab.
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("LoRA"):
            lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Textual Inversion"):
            ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Finetuning"):
            finetune_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("Utilities"):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                headless=headless,
                config=config,
            )
            # LoRA tools live in a sub-tab of Utilities, separate from the
            # top-level LoRA training tab.
            with gr.Tab("LoRA"):
                _ = LoRATools(headless=headless)
        with gr.Tab("About"):
            gr.Markdown(f"kohya_ss GUI release {release}")
            with gr.Tab("README"):
                gr.Markdown(README)

        # Release marker rendered below the tabs.
        version_html = f"""
        <html>
            <body>
                <div class="ver-class">{release}</div>
            </body>
        </html>
        """
        gr.HTML(version_html)

    # Show the interface
    interface.launch(**_build_launch_kwargs(kwargs))
if __name__ == "__main__":
    # torch.cuda.set_per_process_memory_fraction(0.48)

    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # listed in the order they are registered with the parser.
    cli_options = (
        (
            "--config",
            {
                "type": str,
                "default": "./config.toml",
                "help": "Path to the toml config file for interface defaults",
            },
        ),
        ("--debug", {"action": "store_true", "help": "Debug on"}),
        (
            "--listen",
            {
                "type": str,
                "default": "127.0.0.1",
                "help": "IP to listen on for connections to Gradio",
            },
        ),
        (
            "--username",
            {"type": str, "default": "", "help": "Username for authentication"},
        ),
        (
            "--password",
            {"type": str, "default": "", "help": "Password for authentication"},
        ),
        (
            "--server_port",
            {
                "type": int,
                "default": 0,
                "help": "Port to run the server listener on",
            },
        ),
        ("--inbrowser", {"action": "store_true", "help": "Open in browser"}),
        ("--share", {"action": "store_true", "help": "Share the gradio UI"}),
        ("--headless", {"action": "store_true", "help": "Is the server headless"}),
        (
            "--language",
            {"type": str, "default": None, "help": "Set custom language"},
        ),
        ("--use-ipex", {"action": "store_true", "help": "Use IPEX environment"}),
        ("--use-rocm", {"action": "store_true", "help": "Use ROCm environment"}),
        (
            "--do_not_use_shell",
            {
                "action": "store_true",
                "help": "Enforce not to use shell=True when running external commands",
            },
        ),
        (
            "--do_not_share",
            {"action": "store_true", "help": "Do not share the gradio UI"},
        ),
        (
            "--root_path",
            {
                "type": str,
                "default": None,
                "help": "`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # `log` has to stay a module-level name: UI() reads it as a global.
    log = setup_logging(debug=args.debug)

    # Every parsed option is handed to UI() as a keyword argument.
    UI(**vars(args))