# File-viewer header (pasted): lint · commit 6230dda ("init commit") · 8.4 kB
import gradio as gr
from multiprocessing import cpu_count
from src.ui_shared import (
model_ids,
scheduler_names,
default_scheduler,
controlnet_ids,
assets_directory,
)
from src.ui_functions import generate, run_training
# Default value, in pixels, for the Width/Height sliders and the gallery height.
default_img_size = 512

# Static HTML fragments rendered around the UI.
# Decode them explicitly as UTF-8 so non-ASCII markup reads the same way on
# every platform, instead of depending on the locale's default encoding.
with open(f"{assets_directory}/header.html", encoding="utf-8") as fp:
    header = fp.read()
with open(f"{assets_directory}/footer.html", encoding="utf-8") as fp:
    footer = fp.read()

# Shared visual theme applied to the whole Blocks app.
theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)
from gradio.themes.builder_app import css
# The whole UI: every component and every event listener below is created
# inside this Blocks context.
with gr.Blocks(theme=theme) as demo:
    gr.HTML(header)

    # Top row: prompt inputs on the left, model/controlnet/scheduler pickers
    # and the Generate button on the right.
    with gr.Row():
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )
                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )
        with gr.Column(scale=30):
            model_name = gr.Dropdown(
                label="Model", choices=model_ids, value=model_ids[0]
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0]
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler", choices=scheduler_names, value=default_scheduler
            )
            generate_button = gr.Button(value="Generate", variant="primary")

    # Bottom row: settings tabs on the left, results on the right.
    with gr.Row():
        with gr.Column():
            with gr.Tab("Inference") as tab:
                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)
                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=8, step=1
                    )
                    # NOTE(review): -1 presumably means "pick a random seed";
                    # confirm against `generate`.
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
            with gr.Tab("Train Style ControlNet") as tab:
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )
                # Step counts use precision=0 so Gradio hands them to
                # `run_training` as ints rather than floats (e.g. 16000, not
                # 16000.0).
                with gr.Row():
                    max_train_steps = gr.Number(
                        label="Total training steps", value=16000, precision=0
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints",
                        value=4000,
                        precision=0,
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                        precision=0,
                    )
                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label="Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label="Path to validation image folder",
                        value="",
                    )
                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label="Repo for initializing Controlnet Weights",
                        value="andite/anything-v4.0/unet",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory for trained weights", value="./models"
                    )
                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )
                training_button = gr.Button(
                    value="Train Style ControlNet", variant="primary"
                )
                training_status = gr.Text(label="Training Status")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)
            generation_details = gr.Markdown()

    # Disabled extras, kept for reference: raw pipeline kwargs, a VRAM gauge,
    # and the footer HTML.
    # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)
    # if torch.cuda.is_available():
    # giga = 2**30
    # vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
    # demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
    # gr.HTML(footer)

    # Inference wiring. Gradio passes these values to `generate` positionally,
    # so this order must match its signature.
    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]
    # Submitting the prompt box and clicking Generate trigger the same call.
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    # Training wiring. Values go to `run_training` positionally in this order.
    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        max_train_steps,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]
    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # Client-side script taken from gradio.themes.builder_app. It runs once
    # on page load and toggles the `dark` class: if any element already has
    # it, the class is removed everywhere; otherwise it is added to <body>.
    demo.load(
        None,
        None,
        None,
        _js="""() => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
} else {
document.querySelector('body').classList.add('dark');
}
}""",
    )
if __name__ == "__main__":
    # Run as a script: enable request queuing with one concurrent worker per
    # CPU core reported by multiprocessing, then start the web server.
    worker_count = cpu_count()
    queued_app = demo.queue(concurrency_count=worker_count)
    queued_app.launch()