Kohaku-Blueleaf committed
Commit 154b20a · 1 Parent(s): d71fa96

use official pipeline

Files changed (2):
  1. app.py +81 -187
  2. requirements.txt +7 -2
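
In short: the commit deletes the hand-rolled sampling path (the `cfg_wrapper` CFG closure, the manual Euler loop, and the explicit VAE decode) and loads the official `HDMXUTPipeline` instead. Below is a minimal sketch of the new load-and-generate path, assembled only from the calls visible in the diff; the prompts and sizes are illustrative, and the pipeline may accept options not shown here:

```python
import torch
from hdm.pipeline import HDMXUTPipeline  # provided by the HDM package this commit adds

device = torch.device("cuda")
pipe = (
    HDMXUTPipeline.from_pretrained(
        "KBlueLeaf/HDM-xut-340M-anime",
        trust_remote_code=True,
    )
    .to(torch.float16)
    .to(device)
)

images = pipe(
    ["1girl, solo"],                 # prompts (illustrative)
    ["low quality, worst quality"],  # negative prompts (illustrative)
    num_inference_steps=24,
    cfg_scale=4.0,
    width=1024,
    height=1024,
    camera_param={"zoom": 1.0, "x_shift": 0.0, "y_shift": 0.0},
).images
images[0].save("out.png")  # assumes PIL output, as the Gradio gallery consumes it directly
```

The pipeline owns the text encoder, VAE, and sampling loop internally, which is why the diff can drop the separate Qwen3/VAE setup entirely.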
app.py CHANGED

```diff
@@ -1,28 +1,14 @@
 import os
 import random
-import json
-from pathlib import Path
 from functools import partial
 
 if os.environ.get("IN_SPACES", None) is not None:
     in_spaces = True
     import spaces
-
-    os.system(
-        "pip install git+https://${GIT_USER}:${GIT_TOKEN}@github.com/KohakuBlueleaf/XUT"
-    )
 else:
     in_spaces = False
-
 import gradio as gr
-import httpx
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from safetensors.torch import load_file
-from PIL import Image
-from tqdm import trange
 
 try:
     # pre-import triton can avoid diffusers/transformers make import error
@@ -30,18 +16,14 @@ try:
 except ImportError:
     print("Triton not found, skip pre import")
 
-torch.set_float32_matmul_precision("high")
-
 ## HDM model dep
 import xut.env
-
-xut.env.TORCH_COMPILE = False
-xut.env.USE_LIGER = True
-xut.env.USE_XFORMERS = False
-xut.env.USE_XFORMERS_LAYERS = False
-from xut.xut import XUDiT
-from transformers import Qwen3Model, Qwen2Tokenizer
-from diffusers import AutoencoderKL
+xut.env.TORCH_COMPILE = True
+xut.env.USE_LIGER = False
+xut.env.USE_VANILLA = False
+xut.env.USE_XFORMERS = True
+xut.env.USE_XFORMERS_LAYERS = True
+from hdm.pipeline import HDMXUTPipeline
 
 ## TIPO
 import kgen.models as kgen_models
@@ -49,15 +31,19 @@ import kgen.executor.tipo as tipo
 from kgen.formatter import apply_format, seperate_tags
 
 
+torch.set_float32_matmul_precision("high")
+
+
 DEFAULT_FORMAT = """
-<|special|>,
-<|characters|>, <|copyrights|>,
-<|artist|>,
-<|quality|>, <|meta|>, <|rating|>,
+<|special|>,
+<|characters|>, <|copyrights|>,
+<|artist|>,
 
 <|general|>,
 
 <|extended|>.
+
+<|quality|>, <|meta|>, <|rating|>
 """.strip()
 
 
@@ -73,23 +59,6 @@ def GPU(func=None, duration=None):
     return func
 
 
-def download_model(url: str, filepath: str):
-    """Minimal fast download function"""
-    if Path(filepath).exists():
-        print(f"Model already exists at {filepath}")
-        return
-
-    print(f"Downloading model...")
-    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
-
-    with httpx.stream("GET", url, follow_redirects=True) as response:
-        response.raise_for_status()
-        with open(filepath, "wb") as f:
-            for chunk in response.iter_bytes(chunk_size=128 * 1024):
-                f.write(chunk)
-    print(f"Download completed: {filepath}")
-
-
 def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
     meta, operations, general, nl_prompt = tipo.parse_tipo_request(
         seperate_tags(tags.split(",")),
@@ -103,100 +72,24 @@ def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
     return apply_format(result, DEFAULT_FORMAT).strip().strip(".").strip(",")
 
 
-# --- User's core functions (copied directly) ---
-def cfg_wrapper(
-    prompt: str | list[str],
-    neg_prompt: str | list[str],
-    unet: nn.Module,  # should be k_diffusion wrapper
-    te: Qwen3Model,
-    tokenizer: Qwen2Tokenizer,
-    cfg_scale: float = 3.0,
-):
-    prompt_token = {
-        k: v.to(device)
-        for k, v in tokenizer(
-            prompt,
-            padding="longest",
-            return_tensors="pt",
-        ).items()
-    }
-    neg_prompt_token = {
-        k: v.to(device)
-        for k, v in tokenizer(
-            neg_prompt,
-            padding="longest",
-            return_tensors="pt",
-        ).items()
-    }
-
-    emb = te(**prompt_token).last_hidden_state
-    neg_emb = te(**neg_prompt_token).last_hidden_state
-
-    def cfg_fn(x, t, cfg=cfg_scale):
-        cond = unet(x, t.expand(x.size(0)), emb).float()
-        uncond = unet(x, t.expand(x.size(0)), neg_emb).float()
-        return uncond + (cond - uncond) * cfg
-
-    return cfg_fn
-
-
 print("Loading models, please wait...")
 device = torch.device("cuda")
 
 model = (
-    XUDiT(**json.load(open("./config/xut-small-1024-tread.json", "r")))
-    .half()
-    .requires_grad_(False)
-    .eval()
-    .to(device)
-)
-tokenizer = Qwen2Tokenizer.from_pretrained(
-    "Qwen/Qwen3-0.6B",
-)
-te = (
-    Qwen3Model.from_pretrained(
-        "Qwen/Qwen3-0.6B", torch_dtype=torch.float16, attn_implementation="sdpa"
+    HDMXUTPipeline.from_pretrained(
+        "KBlueLeaf/HDM-xut-340M-anime",
+        trust_remote_code=True,
     )
-    .half()
-    .eval()
-    .requires_grad_(False)
+    .to(torch.float16)
     .to(device)
 )
-vae = (
-    AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE")
-    .half()
-    .eval()
-    .requires_grad_(False)
-    .to(device)
-)
-vae_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1).to(device)
-vae_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1).to(device)
-
-
-if not os.path.exists("./model/model.safetensors"):
-    model_file = os.environ.get("MODEL_FILE")
-    os.system(
-        f"hfutils download -t model -r KBlueLeaf/XUT-demo -f {model_file} -o model/model.safetensors"
-    )
-
-state_dict = load_file("./model/model.safetensors")
-model_sd = {
-    k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")
-}
-model_sd = {k.replace("model.", ""): v for k, v in model_sd.items()}
-missing, unexpected = model.load_state_dict(model_sd, strict=False)
-if missing:
-    print(f"Missing keys: {missing}")
-if unexpected:
-    print(f"Unexpected keys: {unexpected}")
-
 
 tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
 kgen_models.load_model(tipo_model_name, device="cuda")
 print("Models loaded successfully. UI is ready.")
 
 
-@GPU(duration=5)
+@GPU(duration=10)
 @torch.no_grad()
 def generate(
     nl_prompt: str,
@@ -210,6 +103,9 @@ def generate(
     size: int,
     aspect_ratio: str,
     fixed_short_edge: bool,
+    zoom: float,
+    x_shift: float,
+    y_shift: float,
     seed: int,
     progress=gr.Progress(),
 ):
@@ -230,7 +126,6 @@
     final_prompt = tag_prompt + "\n" + nl_prompt
 
     yield None, final_prompt
-    all_pil_images = []
 
     prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
     negative_prompts_to_generate = [negative_prompt] * num_images
@@ -246,12 +141,12 @@
     w_factor = aspect_ratio**0.5
     h_factor = 1 / w_factor
 
-    w = int(size * w_factor / 16) * 2
-    h = int(size * h_factor / 16) * 2
+    w = int(size * w_factor / 16) * 16
+    h = int(size * h_factor / 16) * 16
 
     print("=" * 100)
     print(
-        f"Generating {num_images} image(s) with seed: {seed} and resolution {w*8}x{h*8}"
+        f"Generating {num_images} image(s) with seed: {seed} and resolution {w}x{h}"
     )
     print("-" * 80)
     print(f"Final prompt: {final_prompt}")
@@ -262,54 +157,26 @@
     prompts_batch = prompts_to_generate
     neg_prompts_batch = negative_prompts_to_generate
 
-    # Core logic from the original script
-    cfg_fn = cfg_wrapper(
+    images = model(
         prompts_batch,
         neg_prompts_batch,
-        unet=model,
-        te=te,
-        tokenizer=tokenizer,
+        num_inference_steps=steps,
         cfg_scale=cfg_scale,
-    )
-    xt = torch.randn(num_images, 4, h, w).to(device)
-
-    t = 1.0
-    dt = 1.0 / steps
-    with trange(steps, desc="Generating Steps", smoothing=0.05) as cli_prog_bar:
-        for step in progress.tqdm(list(range(steps)), desc="Generating Steps"):
-            with torch.autocast(device.type, dtype=torch.float16):
-                model_pred = cfg_fn(xt, torch.tensor(t, device=device))
-            xt = xt - dt * model_pred.float()
-            t -= dt
-            cli_prog_bar.update(1)
-
-    generated_latents = xt.float()
-    image_tensors = torch.concat(
-        [
-            vae.decode(
-                (generated_latent[None] * vae_std + vae_mean).half()
-            ).sample.cpu()
-            for generated_latent in generated_latents
-        ]
-    )
+        width=w,
+        height=h,
+        camera_param={
+            "zoom": zoom,
+            "x_shift": x_shift,
+            "y_shift": y_shift,
+        }
+    ).images
 
-    # Convert tensors to PIL images
-    for image_tensor in image_tensors:
-        image = Image.fromarray(
-            ((image_tensor * 0.5 + 0.5) * 255)
-            .clamp(0, 255)
-            .numpy()
-            .astype(np.uint8)
-            .transpose(1, 2, 0)
-        )
-        all_pil_images.append(image)
-
-    yield all_pil_images, final_prompt
+    yield images, final_prompt
 
 
 # --- Gradio UI Definition ---
 with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# HDM Early Demo")
+    gr.Markdown("# HDM Demo")
     gr.Markdown(
         "### Enter a natural language prompt and/or specific tags to generate an image."
    )
@@ -318,9 +185,8 @@ with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
 # HDM: HomeDiffusion Model Project
 HDM is a project to implement a series of generative model that can be pretrained at home.
 
-## About this Demo
-This DEMO used a checkpoint during training to demostrate the functionality of HDM.
-Not final model yet.
+* Project Source code: https://github.com/KBlueLeaf/HDM
+* Model: https://huggingface.co/KBlueLeaf/HDM-xut-340M-anime
 
 ## Usage
 This early model used a model trained on anime image set only,
@@ -334,7 +200,7 @@ If you don't want to spent so much effort on prompting, try to keep "Enable TIPO
 If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "Enable Format".
 
 ## Model Spec
-- Backbone: 342M custom DiT(UViT modified) arch
+- Backbone: 343M XUT(UViT modified) arch
 - Text Encoder: Qwen3 0.6B (596M)
 - VAE: EQ-SDXL-VAE, an EQ-VAE finetuned sdxl vae.
 
@@ -359,9 +225,7 @@
             neg_prompt_box = gr.Textbox(
                 label="Negative Prompt",
                 value=(
-                    "low quality, worst quality, "
-                    "jpeg artifacts, bad anatomy, old, early, "
-                    "copyright name, watermark"
+                    "llow quality, worst quality, text, signature, jpeg artifacts, bad anatomy, old, early, copyright name, watermark, artist name, signature, weibo username, realistic"
                 ),
                 lines=3,
             )
@@ -374,18 +238,28 @@ If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "E
                 label="Enable Format",
                 value=True,
             )
+            with gr.Row():
+                zoom_slider = gr.Slider(
+                    label="Zoom", minimum=0.5, maximum=2.0, value=1.0, step=0.01
+                )
+                x_shift_slider = gr.Slider(
+                    label="X Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
+                )
+                y_shift_slider = gr.Slider(
+                    label="Y Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
+                )
         with gr.Column(scale=1):
             with gr.Row():
                 num_images_slider = gr.Slider(
-                    label="Number of Images", minimum=1, maximum=16, value=1, step=1
+                    label="Number of Images", minimum=1, maximum=4, value=1, step=1
                )
                 steps_slider = gr.Slider(
-                    label="Inference Steps", minimum=1, maximum=64, value=32, step=1
+                    label="Inference Steps", minimum=1, maximum=50, value=24, step=1
                )
 
             with gr.Row():
                 cfg_slider = gr.Slider(
-                    label="CFG Scale", minimum=1.0, maximum=5.0, value=3.0, step=0.1
+                    label="CFG Scale", minimum=1.0, maximum=7.0, value=4.0, step=0.1
                )
                 seed_input = gr.Number(
                     label="Seed",
@@ -394,13 +268,31 @@ If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "E
                     info="Set to -1 for a random seed.",
                 )
 
+            with gr.Row():
+                tread_gamma1 = gr.Slider(
+                    label="Tread Gamma 1",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.0,
+                    step=0.05,
+                    interactive=True,
+                )
+                tread_gamma1_slider = gr.Slider(
+                    label="Tread Gamma 2",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.25,
+                    step=0.05,
+                    interactive=True,
+                )
+
             with gr.Row():
                 size_slider = gr.Slider(
                     label="Base Image Size",
-                    minimum=384,
-                    maximum=768,
-                    value=512,
-                    step=64,
+                    minimum=768,
+                    maximum=1280,
+                    value=1024,
+                    step=16,
                 )
             with gr.Row():
                 aspect_ratio_box = gr.Textbox(
@@ -412,10 +304,9 @@ If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "E
                     value=True,
                 )
 
-    generate_button = gr.Button("Generate", variant="primary")
-
     with gr.Row():
         with gr.Column(scale=1):
+            generate_button = gr.Button("Generate", variant="primary")
             output_prompt = gr.TextArea(
                 label="Final Prompt",
                 show_label=True,
@@ -428,7 +319,7 @@ If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "E
                 label="Generated Images",
                 show_label=True,
                 elem_id="gallery",
-                columns=4,
+                columns=2,
                 rows=3,
                 height="800px",
             )
@@ -447,6 +338,9 @@ If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "E
             size_slider,
             aspect_ratio_box,
             fixed_short_edge,
+            zoom_slider,
+            x_shift_slider,
+            y_shift_slider,
             seed_input,
         ],
         outputs=[output_gallery, output_prompt],
```
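
One behavioural detail worth calling out: the resolution math changed meaning. The old code computed latent-space sizes (hence `torch.randn(num_images, 4, h, w)` and the `{w*8}x{h*8}` log line), while the new code snaps pixel dimensions down to multiples of 16 and passes them straight to the pipeline. A small sketch of what the new snapping produces; the sample values are illustrative:

```python
# New-code behaviour: pixel dimensions snapped down to a multiple of 16.
def dims(size: int, aspect_ratio: float) -> tuple[int, int]:
    w_factor = aspect_ratio**0.5  # widen by sqrt(AR) ...
    h_factor = 1 / w_factor       # ... and shrink height by the same factor
    w = int(size * w_factor / 16) * 16
    h = int(size * h_factor / 16) * 16
    return w, h

print(dims(1024, 1.0))     # (1024, 1024)
print(dims(1024, 16 / 9))  # (1360, 768): long edge grows, short edge shrinks
```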
requirements.txt CHANGED

```diff
@@ -1,7 +1,12 @@
+--index-url https://download.pytorch.org/whl/cu128
+--extra-index-url https://pypi.org/simple/
+torch
+torchvision
+xformers
+accelerate
 transformers
 diffusers
 tqdm
-torch
 pillow
 tipo-kgen
 safetensors
@@ -11,4 +16,4 @@ httpx
 einops
 hfutils[transfer]
 sentencepiece
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.13-cu124/llama_cpp_python-0.3.13-cp310-cp310-linux_x86_64.whl
+git+https://github.com/KohakuBlueleaf/HDM
```
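
For requirements.txt, the notable change is that torch, torchvision, and xformers now resolve from the CUDA 12.8 wheel index, and HDM itself installs from GitHub rather than a prebuilt llama-cpp wheel. A quick sanity check, assuming an environment built from this file; this snippet is not part of the commit:

```python
import torch
import xformers  # resolved from the cu128 index alongside torch/torchvision

# Expect a CUDA 12.8 build of torch if the --index-url line took effect.
print(torch.__version__, torch.version.cuda)
print(xformers.__version__)
assert torch.cuda.is_available(), "CUDA runtime not visible to torch"
```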