# Spaces: Build error
import argparse
import logging
import os

import gradio as gr

from kohya_gui.class_gui_config import KohyaSSGUIConfig
from kohya_gui.dreambooth_gui import dreambooth_tab
from kohya_gui.finetune_gui import finetune_tab
from kohya_gui.textual_inversion_gui import ti_tab
from kohya_gui.utilities import utilities_tab
from kohya_gui.lora_gui import lora_tab
from kohya_gui.class_lora_tab import LoRATools
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript
def _read_text_file(path, default=""):
    """Return the UTF-8 text of ``path``, or ``default`` when the file is absent.

    Used for optional assets (release marker, README, stylesheet) so that a
    missing file never leaves a variable unbound.
    """
    if not os.path.exists(path):
        return default
    with open(path, "r", encoding="utf8") as file:
        return file.read()


def _build_launch_kwargs(kwargs):
    """Translate UI() keyword options into arguments for ``Blocks.launch``.

    ``server_name`` and ``debug`` are always set; every other option is only
    forwarded when the caller set it, so Gradio's own defaults apply otherwise.
    ``do_not_share`` takes precedence over ``share``.
    """
    launch_kwargs = {"server_name": kwargs.get("listen")}

    username = kwargs.get("username")
    password = kwargs.get("password")
    # Authentication is only enabled when both credentials are non-empty.
    if username and password:
        launch_kwargs["auth"] = (username, password)

    # `or 0` also tolerates an explicit None, which would break the `> 0` test.
    server_port = kwargs.get("server_port") or 0
    if server_port > 0:
        launch_kwargs["server_port"] = server_port

    inbrowser = kwargs.get("inbrowser", False)
    if inbrowser:
        launch_kwargs["inbrowser"] = inbrowser

    share = kwargs.get("share", False)
    if kwargs.get("do_not_share", False):
        launch_kwargs["share"] = False
    elif share:
        launch_kwargs["share"] = share

    root_path = kwargs.get("root_path", None)
    if root_path:
        launch_kwargs["root_path"] = root_path

    launch_kwargs["debug"] = True
    return launch_kwargs


def UI(**kwargs):
    """Build the kohya_ss Gradio interface and launch it.

    Keyword options read here: ``language``, ``headless``, ``config``,
    ``do_not_use_shell`` (interface construction) and ``listen``,
    ``username``, ``password``, ``server_port``, ``inbrowser``, ``share``,
    ``do_not_share``, ``root_path`` (forwarded to ``Blocks.launch``).

    The optional files ``./assets/style.css``, ``./.release`` and
    ``./README.md`` may be missing. Release and README text then default to
    empty strings instead of raising ``NameError``.
    """
    # `log` is bound at module level only when this file runs as a script;
    # fall back to a standard logger so UI() also works after an import.
    logger = globals().get("log") or logging.getLogger(__name__)

    add_javascript(kwargs.get("language"))
    headless = kwargs.get("headless", False)
    logger.info(f"headless: {headless}")

    css = ""
    if os.path.exists("./assets/style.css"):
        logger.debug("Load CSS...")
        css += _read_text_file("./assets/style.css") + "\n"

    release = _read_text_file("./.release")
    readme = _read_text_file("./README.md")

    interface = gr.Blocks(
        css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
    )

    config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
    if config.is_config_loaded():
        logger.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")

    # Config may override the shell default; the CLI flag always wins.
    use_shell_flag = config.get("settings.use_shell", True)
    if kwargs.get("do_not_use_shell", False):
        use_shell_flag = False
    if use_shell_flag:
        logger.info("Using shell=True when running external commands...")

    with interface:
        with gr.Tab("Dreambooth"):
            # The Dreambooth tab's folder inputs are shared with Utilities.
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("LoRA"):
            lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Textual Inversion"):
            ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
        with gr.Tab("Finetuning"):
            finetune_tab(
                headless=headless, config=config, use_shell_flag=use_shell_flag
            )
        with gr.Tab("Utilities"):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                headless=headless,
                config=config,
            )
            # LoRA tools are a sub-tab of Utilities; a second top-level
            # "LoRA" tab would duplicate the training tab's label.
            with gr.Tab("LoRA"):
                _ = LoRATools(headless=headless)
        with gr.Tab("About"):
            gr.Markdown(f"kohya_ss GUI release {release}")
            with gr.Tab("README"):
                gr.Markdown(readme)

        html_str = f"""
        <html>
            <body>
                <div class="ver-class">{release}</div>
            </body>
        </html>
        """
        gr.HTML(html_str)

    interface.launch(**_build_launch_kwargs(kwargs))
if __name__ == "__main__":
    # Command-line options in the order they appear in --help. Each entry is
    # (flag, keyword arguments passed to ArgumentParser.add_argument).
    cli_options = (
        (
            "--config",
            {
                "type": str,
                "default": "./config.toml",
                "help": "Path to the toml config file for interface defaults",
            },
        ),
        ("--debug", {"action": "store_true", "help": "Debug on"}),
        (
            "--listen",
            {
                "type": str,
                "default": "127.0.0.1",
                "help": "IP to listen on for connections to Gradio",
            },
        ),
        (
            "--username",
            {"type": str, "default": "", "help": "Username for authentication"},
        ),
        (
            "--password",
            {"type": str, "default": "", "help": "Password for authentication"},
        ),
        (
            "--server_port",
            {"type": int, "default": 0, "help": "Port to run the server listener on"},
        ),
        ("--inbrowser", {"action": "store_true", "help": "Open in browser"}),
        ("--share", {"action": "store_true", "help": "Share the gradio UI"}),
        ("--headless", {"action": "store_true", "help": "Is the server headless"}),
        (
            "--language",
            {"type": str, "default": None, "help": "Set custom language"},
        ),
        ("--use-ipex", {"action": "store_true", "help": "Use IPEX environment"}),
        ("--use-rocm", {"action": "store_true", "help": "Use ROCm environment"}),
        (
            "--do_not_use_shell",
            {
                "action": "store_true",
                "help": "Enforce not to use shell=True when running external commands",
            },
        ),
        (
            "--do_not_share",
            {"action": "store_true", "help": "Do not share the gradio UI"},
        ),
        (
            "--root_path",
            {
                "type": str,
                "default": None,
                "help": "`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss",
            },
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)
    cli_args = arg_parser.parse_args()

    # `log` must remain a module-level name because UI() reads it as a global.
    log = setup_logging(debug=cli_args.debug)
    UI(**vars(cli_args))