|
import os |
|
import random |
|
from functools import partial |
|
|
|
# `spaces` is only imported when the IN_SPACES environment variable is present
# (any value, including an empty string, counts as present). `in_spaces`
# records which mode the rest of the script runs in.
in_spaces = os.environ.get("IN_SPACES") is not None
if in_spaces:
    import spaces
|
import gradio as gr |
|
import torch |
|
|
|
# Optional pre-import of Triton before xut/hdm are imported below. The name
# `triton` is not referenced elsewhere in the visible source; presumably the
# import is only here so later compiled code (xut.env.TORCH_COMPILE is enabled
# below) finds it already loaded -- TODO confirm.
try:
    import triton
except ImportError:
    # Triton is optional: report it and continue without the pre-import.
    print("Triton not found, skip pre import")
|
|
|
|
|
import xut.env |
|
# Backend switches on the xut.env module. They are assigned before
# `hdm.pipeline` is imported on the next line; presumably xut/hdm read them at
# import or model-construction time, so keep this ordering (TODO confirm).
xut.env.TORCH_COMPILE = True
xut.env.USE_LIGER = False
xut.env.USE_VANILLA = False
xut.env.USE_XFORMERS = True
xut.env.USE_XFORMERS_LAYERS = True
|
from hdm.pipeline import HDMXUTPipeline |
|
|
|
|
|
import kgen.models as kgen_models |
|
import kgen.executor.tipo as tipo |
|
from kgen.formatter import apply_format, seperate_tags |
|
|
|
|
|
# Let PyTorch use faster reduced-internal-precision kernels (e.g. TF32) for
# float32 matrix multiplications where the hardware supports them.
torch.set_float32_matmul_precision("high")
|
|
|
|
|
# Prompt template handed to kgen's `apply_format`. Each `<|name|>` placeholder
# is presumably replaced by the matching category of the mapping passed in
# (tags or extended natural-language text) -- confirm against kgen.formatter.
# `.strip()` only removes leading/trailing whitespace of the whole template.
DEFAULT_FORMAT = """

<|special|>,

<|characters|>, <|copyrights|>,

<|artist|>,



<|general|>,



<|extended|>.



<|quality|>, <|meta|>, <|rating|>

""".strip()
|
|
|
|
|
def GPU(func=None, duration=None):
    """Wrap *func* with ``spaces.GPU`` when ``in_spaces`` is set; otherwise return it unchanged.

    Works both bare (``@GPU``) and with arguments (``@GPU(duration=10)``).

    Args:
        func: Callable to wrap. When omitted, a decorator is returned that
            accepts the callable later, remembering ``duration``.
        duration: Forwarded to ``spaces.GPU`` as ``duration=`` only when
            truthy; ``None``/``0`` are not forwarded at all.

    Returns:
        A decorator (when ``func`` is None), the ``spaces.GPU``-wrapped
        callable, or ``func`` itself when not running in Spaces mode.
    """
    if func is None:
        # Called as @GPU(duration=...): defer until the function arrives.
        return partial(GPU, duration=duration)
    if not in_spaces:
        return func
    gpu_kwargs = {"duration": duration} if duration else {}
    return spaces.GPU(func, **gpu_kwargs)
|
|
|
|
|
def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    """Expand the user's prompt with TIPO and render it through DEFAULT_FORMAT.

    Args:
        tags: Comma-separated tag string typed by the user (may be empty).
        nl_prompt: Natural-language prompt typed by the user (may be empty).
        aspect_ratio: Target width / height ratio; sent to TIPO as metadata,
            formatted with three decimals.
        seed: Seed forwarded to ``tipo.tipo_runner``.

    Returns:
        The formatted prompt with any run of whitespace, commas and periods
        removed from both ends (presumably leftovers of empty template slots).
    """
    # Strip each tag and drop empty entries so " solo" keeps no leading space
    # and an empty box or a trailing comma does not produce a "" tag.
    tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()]
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tag_list),
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = f"{aspect_ratio:.3f}"
    result, _timing = tipo.tipo_runner(meta, operations, general, nl_prompt, seed=seed)
    formatted = apply_format(result, DEFAULT_FORMAT)
    # Trim whitespace, "." and "," in one pass: stripping them one after the
    # other could re-expose whitespace (", foo" used to become " foo").
    return formatted.strip(" \t\r\n\f\v.,")
|
|
|
|
|
print("Loading models, please wait...")
# Single place that decides where both models are loaded.
device = torch.device("cuda")

# HDM text-to-image pipeline, cast to float16 and moved to `device`.
model = (
    HDMXUTPipeline.from_pretrained(
        "KBlueLeaf/HDM-xut-340M-anime",
        trust_remote_code=True,
    )
    .to(torch.float16)
    .to(device)
)

# TIPO prompt-expansion model: the first entry of kgen's model list.
# `gguf_list` is not used in the visible source but stays a module-level name.
tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
# str(torch.device("cuda")) == "cuda", so kgen still receives a plain device
# string, but it now follows `device` instead of repeating the literal.
kgen_models.load_model(tipo_model_name, device=str(device))
print("Models loaded successfully. UI is ready.")
|
|
|
|
|
def _parse_aspect_ratio(text):
    """Parse the aspect-ratio textbox into a positive, finite width / height float.

    Accepts "W:H" (e.g. "16:9") and, additionally, a bare number (e.g. "1.5").
    Surrounding whitespace is ignored.

    Raises:
        gr.Error: if the text is malformed, has a zero height, or the ratio is
            not a positive finite number, so the UI shows a readable message
            instead of a traceback.
    """
    text = (text or "").strip()
    try:
        if ":" in text:
            width_text, height_text = text.split(":")
            ratio = float(width_text) / float(height_text)
        else:
            ratio = float(text)
    except (ValueError, ZeroDivisionError) as err:
        raise gr.Error(
            f"Invalid aspect ratio {text!r}; use the form W:H, e.g. 16:9."
        ) from err
    # The chained comparison is False for NaN, inf and non-positive values.
    if not 0 < ratio < float("inf"):
        raise gr.Error(f"Aspect ratio must be positive and finite, got {text!r}.")
    return ratio


def _build_prompt(
    nl_prompt, tag_prompt, negative_prompt, tipo_enable, format_enable, aspect_ratio, seed
):
    """Return the final prompt text according to the TIPO / format switches."""
    if tipo_enable:
        # Module-level setting on kgen's tipo executor, shared by every request.
        # NOTE(review): concurrent requests may overwrite each other's list.
        tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
        return prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    if format_enable:
        # Build the same kind of category mapping the TIPO path formats, so
        # both the tags and the natural-language prompt reach the template.
        # The original passed only the NL string and silently dropped the
        # tags. Assumes seperate_tags returns a mutable dict and that
        # "extended" is the template's natural-language slot -- confirm
        # against kgen.formatter.
        tag_map = seperate_tags(
            [tag.strip() for tag in tag_prompt.split(",") if tag.strip()]
        )
        tag_map["extended"] = nl_prompt.strip()
        return apply_format(tag_map, DEFAULT_FORMAT).strip(" \t\r\n\f\v.,")
    # Raw mode: tags then natural language, skipping empty parts so no stray
    # leading/trailing newline is left behind.
    return "\n".join(part for part in (tag_prompt, nl_prompt) if part.strip())


def _compute_resolution(size, aspect_ratio, fixed_short_edge):
    """Return (width, height), each floored to a multiple of 16 and at least 16."""
    if fixed_short_edge:
        # The short edge equals `size`; the long edge is scaled by the ratio.
        if aspect_ratio > 1:
            w_factor, h_factor = aspect_ratio, 1
        else:
            w_factor, h_factor = 1, 1 / aspect_ratio
    else:
        # Keep the pixel area close to size * size.
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor
    # max(16, ...) keeps an extreme ratio from rounding an edge down to 0 px.
    width = max(16, int(size * w_factor / 16) * 16)
    height = max(16, int(size * h_factor / 16) * 16)
    return width, height


@GPU(duration=10)
@torch.no_grad()
def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    tipo_enable: bool,
    format_enable: bool,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    zoom: float,
    x_shift: float,
    y_shift: float,
    tread_gamma1: float,
    tread_gamma2: float,
    seed: int,
    progress=gr.Progress(),
):
    """Gradio handler: build the final prompt, then run the HDM pipeline.

    This is a generator with two yields:
    1. ``(None, final_prompt)`` right after the prompt is built, so the UI
       shows it (and clears the gallery) before the slow diffusion run.
    2. ``(images, final_prompt)`` with the pipeline's ``.images`` result.

    Args:
        aspect_ratio: Text such as "16:9" (a bare number like "1.5" is also
            accepted).
        seed: -1 (or an empty box, which arrives as None) picks a random seed
            in [0, 2**32 - 1].
        Other arguments mirror the UI controls and are forwarded to the
        pipeline or the prompt builder.

    Raises:
        gr.Error: for an unparsable or non-positive aspect ratio.
    """
    ratio = _parse_aspect_ratio(aspect_ratio)
    num_images = int(num_images)

    if seed is None or seed == -1:
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)
    torch.manual_seed(seed)

    final_prompt = _build_prompt(
        nl_prompt, tag_prompt, negative_prompt, tipo_enable, format_enable, ratio, seed
    )
    yield None, final_prompt

    w, h = _compute_resolution(size, ratio, fixed_short_edge)

    print("=" * 100)
    print(
        f"Generating {num_images} image(s) with seed: {seed} and resolution {w}x{h}"
    )
    print("-" * 80)
    print(f"Final prompt: {final_prompt}")
    print("-" * 80)
    print(f"Negative prompt: {negative_prompt}")
    print("-" * 80)

    # The pipeline receives a single-line prompt; the UI keeps the multi-line one.
    flat_prompt = final_prompt.replace("\n", " ")
    images = model(
        [flat_prompt] * num_images,
        [negative_prompt] * num_images,
        num_inference_steps=int(steps),
        cfg_scale=cfg_scale,
        width=w,
        height=h,
        camera_param={
            "zoom": zoom,
            "x_shift": x_shift,
            "y_shift": y_shift,
        },
        tread_gamma1=tread_gamma1,
        tread_gamma2=tread_gamma2,
    ).images

    yield images, final_prompt
|
|
|
|
|
|
|
# Gradio UI. The layout is: header, an introduction accordion, an input row
# (prompts + switches + zoom/shift on the left, sampling settings on the
# right), then an output row (button + final prompt, image gallery).
with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# HDM Demo")
    gr.Markdown(
        "### Enter a natural language prompt and/or specific tags to generate an image."
    )
    with gr.Accordion("Introduction", open=False):
        gr.Markdown("""

# HDM: HomeDiffusion Model Project

HDM is a project to implement a series of generative models that can be pretrained at home.



* Project Source code: https://github.com/KBlueLeaf/HDM

* Model: https://huggingface.co/KBlueLeaf/HDM-xut-340M-anime



## Usage

This early release uses a model trained on an anime image set only,

so you should expect to see anime style images only in this demo.



For prompting, enter Danbooru tags into the "Tag Prompt" box, comma-separated, with underscores replaced by spaces.

Enter the natural language prompt into the "Natural Language Prompt" box and the negative prompt into the "Negative Prompt" box.



If you don't want to spend much effort on prompting, try to keep "Enable TIPO" selected.



If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "Enable Format".



## Model Spec

- Backbone: 343M XUT(UViT modified) arch

- Text Encoder: Qwen3 0.6B (596M)

- VAE: EQ-SDXL-VAE, an EQ-VAE finetuned sdxl vae.



## Pretraining Dataset

- Danbooru 2023 (latest id around 8M)

- Pixiv famous artist set

- some pvc figure photos

""")

    with gr.Row():
        # Left column: prompts, prompt-processing switches, zoom/shift sliders
        # (generate() forwards the sliders as the pipeline's camera_param).
        with gr.Column(scale=2):
            nl_prompt_box = gr.Textbox(
                label="Natural Language Prompt",
                placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
                lines=3,
            )
            tag_prompt_box = gr.Textbox(
                label="Tag Prompt (comma-separated)",
                placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
                lines=3,
            )
            neg_prompt_box = gr.Textbox(
                label="Negative Prompt",
                # Fixed typo: the first tag was "llow quality".
                value=(
                    "low quality, worst quality, text, signature, jpeg artifacts, bad anatomy, old, early, copyright name, watermark, artist name, signature, weibo username, realistic"
                ),
                lines=3,
            )
            with gr.Row():
                tipo_enable = gr.Checkbox(
                    label="Enable TIPO",
                    value=True,
                )
                format_enable = gr.Checkbox(
                    label="Enable Format",
                    value=True,
                )
            with gr.Row():
                zoom_slider = gr.Slider(
                    label="Zoom", minimum=0.5, maximum=2.0, value=1.0, step=0.01
                )
                x_shift_slider = gr.Slider(
                    label="X Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
                )
                y_shift_slider = gr.Slider(
                    label="Y Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
                )

        # Right column: sampling and resolution settings.
        with gr.Column(scale=1):
            with gr.Row():
                num_images_slider = gr.Slider(
                    label="Number of Images", minimum=1, maximum=4, value=1, step=1
                )
                steps_slider = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=50, value=24, step=1
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=7.0, value=4.0, step=0.1
                )
                seed_input = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    info="Set to -1 for a random seed.",
                )

            with gr.Row():
                tread_gamma1_slider = gr.Slider(
                    label="Tread Gamma 1",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.05,
                    interactive=True,
                )
                tread_gamma2_slider = gr.Slider(
                    label="Tread Gamma 2",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.05,
                    interactive=True,
                )

            with gr.Row():
                # step=16 keeps the base size on the multiple-of-16 grid that
                # generate() floors width/height to.
                size_slider = gr.Slider(
                    label="Base Image Size",
                    minimum=768,
                    maximum=1280,
                    value=1024,
                    step=16,
                )
            with gr.Row():
                aspect_ratio_box = gr.Textbox(
                    label="Ratio (W:H)",
                    value="1:1",
                )
                fixed_short_edge = gr.Checkbox(
                    label="Fixed Edge",
                    value=True,
                )

    with gr.Row():
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")
            output_prompt = gr.TextArea(
                label="Final Prompt",
                show_label=True,
                interactive=False,
                lines=32,
                max_lines=32,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=3,
                height="800px",
            )

    # The input order must match generate()'s positional parameters exactly.
    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            tipo_enable,
            format_enable,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            zoom_slider,
            x_shift_slider,
            y_shift_slider,
            tread_gamma1_slider,
            tread_gamma2_slider,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )

if __name__ == "__main__":
    demo.launch()
|
|