WYBar committed
Commit d9c4eb2 · Parent: 13ab714

generator device + seed_everything + custom_pipeline

Files changed (3):
  1. app.py +39 -62
  2. custom_pipeline.py +348 -67
  3. requirements.txt +2 -1
app.py CHANGED
@@ -21,6 +21,7 @@ import base64
21
  import os
22
  import time
23
  import re
 
24
 
25
  from transformers import (
26
  AutoTokenizer,
@@ -37,6 +38,15 @@ class StopAtSpecificTokenCriteria(StoppingCriteria):
37
  @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
38
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
  return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
40
 
41
  def ensure_space_after_period(input_string):
42
  # strip extra whitespace
@@ -232,13 +242,8 @@ def construction_all():
232
  global model
233
  global quantizer
234
  global tokenizer
235
- global pipeline
236
- global transp_vae
237
  from modeling_crello import CrelloModel, CrelloModelConfig
238
  from quantizer import get_quantizer
239
- from custom_model_mmdit import CustomFluxTransformer2DModel
240
- from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
241
- from custom_pipeline import CustomFluxPipelineCfg
242
 
243
  params_dict = {
244
  "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
@@ -314,13 +319,28 @@ def construction_all():
314
  for token in added_special_tokens_list:
315
  quantizer.additional_special_tokens.add(token)
316
317
  transformer = CustomFluxTransformer2DModel.from_pretrained(
318
  "WYBar/ART_test_weights",
319
  subfolder="fused_transformer",
320
  torch_dtype=torch.bfloat16,
321
- # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
322
  )
323
-
324
  transp_vae = CustomVAE.from_pretrained(
325
  "WYBar/ART_test_weights",
326
  subfolder="custom_vae",
@@ -328,16 +348,7 @@ def construction_all():
328
  use_safetensors=True,
329
  # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
330
  )
331
-
332
- token = os.environ.get("HF_TOKEN")
333
- pipeline = CustomFluxPipelineCfg.from_pretrained(
334
- "black-forest-labs/FLUX.1-dev",
335
- transformer=transformer,
336
- torch_dtype=torch.bfloat16,
337
- token=token,
338
- # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
339
- ).to("cuda")
340
- pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
341
 
342
  print(f"before .to(device):{model.device} {model.lm.device} {pipeline.device}")
343
  model = model.to("cuda")
@@ -421,6 +432,7 @@ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps,
421
  num_layers=len(validation_box),
422
  guidance_scale=4.0,
423
  num_inference_steps=inference_steps,
 
424
  transparent_decoder=transp_vae,
425
  true_gs=true_gs
426
  )
@@ -440,25 +452,29 @@ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps,
440
 
441
  def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
442
  print(f"svg_test_one_sample {model.device} {model.lm.device} {pipeline.device}")
443
- generator = torch.Generator().manual_seed(seed)
 
444
  try:
445
- validation_box = ast.literal_eval(validation_box_str)
 
 
 
446
  except Exception as e:
447
  return [f"Error parsing validation_box: {e}"]
 
448
  if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
449
  return ["validation_box must be a list of tuples, each of length 4."]
450
-
451
  validation_box = adjust_validation_box(validation_box)
452
 
453
  print("result_images = test_one_sample")
454
  result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
455
  print("after result_images = test_one_sample")
456
  svg_img = pngs_to_svg(result_images[1:])
457
-
458
  svg_file_path = './image.svg'
459
  os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
460
  with open(svg_file_path, 'w', encoding='utf-8') as f:
461
- f.write(svg_img)
462
 
463
  if not isinstance(result_images, list):
464
  raise TypeError("result_images must be a list")
@@ -475,7 +491,7 @@ def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, in
475
  def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
476
  print(f"precess_svg {model.device} {model.lm.device} {pipeline.device}")
477
  result_images = []
478
- result_images, svg_file_path = svg_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
479
  # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
480
 
481
  url, unique_filename = upload_to_github(file_path=svg_file_path)
@@ -543,45 +559,6 @@ def main():
543
  construction_all()
544
  print(f"after construction_all:{model.device} {model.lm.device} {pipeline.device}")
545
 
546
- # def process_preddate(intention, generate_method='v1'):
547
- # list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
548
- # wholecaption = "Design an engaging and vibrant recruitment advertisement for our company. The image should feature three animated characters in a modern cityscape, depicting a dynamic and collaborative work environment. Incorporate a light bulb graphic with a question mark, symbolizing innovation, creativity, and problem-solving. Use bold text to announce \"WE ARE RECRUITING\" and provide the company's social media handle \"@reallygreatsite\" and a contact phone number \"+123-456-7890\" for interested individuals. The overall design should be playful and youthful, attracting potential recruits who are innovative and eager to contribute to a lively team."
549
- # json_file = "/home/wyb/openseg_blob/v-yanbin/GradioDemo/LLM-For-Layout-Planning/inference_test.json"
550
- # return wholecaption, str(list_box), json_file
551
-
552
- # pipeline, transp_vae = construction()
553
-
554
- # gradio_test_one_sample_partial = partial(
555
- # svg_test_one_sample,
556
- # pipeline=pipeline,
557
- # transp_vae=transp_vae,
558
- # )
559
-
560
- # def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
561
- # print("precess_svg")
562
- # result_images = []
563
- # result_images, svg_file_path = svg_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
564
- # # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
565
-
566
- # url, unique_filename = upload_to_github(file_path=svg_file_path)
567
- # unique_filename = f'{unique_filename}'
568
-
569
- # if url != None:
570
- # print(f"File uploaded to: {url}")
571
- # svg_editor = f"""
572
- # <iframe src="https://svgedit.netlify.app/editor/index.html?\
573
- # storagePrompt=false&url={url}" \
574
- # width="100%", height="800px"></iframe>
575
- # """
576
- # else:
577
- # print('upload_to_github FAILED!')
578
- # svg_editor = f"""
579
- # <iframe src="https://svgedit.netlify.app/editor/index.html" \
580
- # width="100%", height="800px"></iframe>
581
- # """
582
-
583
- # return result_images, svg_file_path, svg_editor
584
-
585
  def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
586
  # 首先调用process_preddate
587
  list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
 
21
  import os
22
  import time
23
  import re
24
+ import random
25
 
26
  from transformers import (
27
  AutoTokenizer,
 
38
  @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
39
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
40
  return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list
41
+
42
+ def seed_everything(seed):
43
+ random.seed(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ if torch.cuda.is_available():
47
+ torch.cuda.manual_seed(seed)
48
+ torch.cuda.manual_seed_all(seed)
49
+ torch.backends.cudnn.deterministic = True
50
 
51
  def ensure_space_after_period(input_string):
52
  # strip extra whitespace
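For context on the new helper: a minimal, self-contained sketch of how `seed_everything` is meant to be used. The function body mirrors the one added above, the `42` matches the value hard-coded in `construction_all`, and the final `torch.randn` call is only an illustration.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    # Seed every RNG the demo touches: Python, NumPy, PyTorch CPU and all CUDA devices.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trades some cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.deterministic = True


seed_everything(42)    # same value construction_all() uses
print(torch.randn(3))  # reproducible across runs with the same seed
```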
 
242
  global model
243
  global quantizer
244
  global tokenizer
245
  from modeling_crello import CrelloModel, CrelloModelConfig
246
  from quantizer import get_quantizer
247
 
248
  params_dict = {
249
  "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
 
319
  for token in added_special_tokens_list:
320
  quantizer.additional_special_tokens.add(token)
321
 
322
+ global pipeline
323
+ global transp_vae
324
+ seed_everything(42)
325
+ from custom_model_mmdit import CustomFluxTransformer2DModel
326
+ from custom_pipeline import CustomFluxPipelineCfg
327
+
328
  transformer = CustomFluxTransformer2DModel.from_pretrained(
329
  "WYBar/ART_test_weights",
330
  subfolder="fused_transformer",
331
  torch_dtype=torch.bfloat16,
332
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
333
  )
334
+
335
+ pipeline = CustomFluxPipelineCfg.from_pretrained(
336
+ "black-forest-labs/FLUX.1-dev",
337
+ transformer=transformer,
338
+ torch_dtype=torch.bfloat16,
339
+ cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
340
+ ).to("cuda")
341
+ # pipeline.enable_model_cpu_offload(gpu_id=0) # save vram
342
+
343
+ from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
344
  transp_vae = CustomVAE.from_pretrained(
345
  "WYBar/ART_test_weights",
346
  subfolder="custom_vae",
 
348
  use_safetensors=True,
349
  # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
350
  )
351
+ transp_vae.eval()
352
 
353
  print(f"before .to(device):{model.device} {model.lm.device} {pipeline.device}")
354
  model = model.to("cuda")
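A rough sketch of what the reworked `construction_all` now does on the diffusers side (illustration only: the stock `FluxTransformer2DModel`/`FluxPipeline` classes stand in for the repo's `CustomFluxTransformer2DModel`/`CustomFluxPipelineCfg`). Keeping the whole pipeline on GPU with `.to("cuda")` is the fast path; the now-commented `enable_model_cpu_offload(gpu_id=0)` call is the lower-VRAM alternative.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

if torch.cuda.is_available():
    pipe.to("cuda")                            # fastest: everything resident on one GPU
    # pipe.enable_model_cpu_offload(gpu_id=0)  # alternative: stream modules to the GPU on demand
```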
 
432
  num_layers=len(validation_box),
433
  guidance_scale=4.0,
434
  num_inference_steps=inference_steps,
435
+ sdxl_vae=transp_vae,
436
  transparent_decoder=transp_vae,
437
  true_gs=true_gs
438
  )
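The new `sdxl_vae=transp_vae` keyword matches the parameter added to `CustomFluxPipelineCfg.__call__` in `custom_pipeline.py`; the transparent decoder is passed under both the new and the old name during the transition. A hedged sketch of the call site, with placeholder prompt, boxes, and sizes:

```python
import torch

def run_layout_sample(pipeline, transp_vae, prompt, validation_box, seed=42):
    # Sketch of the updated test_one_sample call; keyword names follow the diff.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    output, result_list, vis_list, latents = pipeline(
        prompt=prompt,
        validation_box=validation_box,       # list of (x1, y1, x2, y2) boxes, one per layer
        generator=generator,
        height=512,
        width=512,
        num_layers=len(validation_box),
        guidance_scale=4.0,
        num_inference_steps=28,
        sdxl_vae=transp_vae,                 # new keyword: used for the RGBA decode
        transparent_decoder=transp_vae,      # still accepted for backward compatibility
        true_gs=3.5,
    )
    return result_list                       # per-layer RGBA arrays (uint8)
```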
 
452
 
453
  def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
454
  print(f"svg_test_one_sample {model.device} {model.lm.device} {pipeline.device}")
455
+ # generator = torch.Generator().manual_seed(seed)
456
+ generator = torch.Generator(device=torch.device("cuda", index=0)).manual_seed(seed)
457
  try:
458
+ if isinstance(validation_box_str, (list, tuple)):
459
+ validation_box = validation_box_str
460
+ else:
461
+ validation_box = ast.literal_eval(validation_box_str)
462
  except Exception as e:
463
  return [f"Error parsing validation_box: {e}"]
464
+
465
  if not isinstance(validation_box, list) or not all(isinstance(t, tuple) and len(t) == 4 for t in validation_box):
466
  return ["validation_box must be a list of tuples, each of length 4."]
 
467
  validation_box = adjust_validation_box(validation_box)
468
 
469
  print("result_images = test_one_sample")
470
  result_images = test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae)
471
  print("after result_images = test_one_sample")
472
  svg_img = pngs_to_svg(result_images[1:])
473
+
474
  svg_file_path = './image.svg'
475
  os.makedirs(os.path.dirname(svg_file_path), exist_ok=True)
476
  with open(svg_file_path, 'w', encoding='utf-8') as f:
477
+ f.write(svg_img)
478
 
479
  if not isinstance(result_images, list):
480
  raise TypeError("result_images must be a list")
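The switch to a CUDA-bound generator above is not cosmetic: `randn_tensor` samples the initial latents on the generator's device, and a CPU generator and a `cuda:0` generator produce different noise for the same seed. Pinning the generator to the GPU therefore keeps results reproducible on the deployed device. Purely illustrative comparison:

```python
import torch

seed = 42

# What the old code effectively did: CPU RNG stream.
cpu_gen = torch.Generator().manual_seed(seed)
cpu_noise = torch.randn(4, generator=cpu_gen)

if torch.cuda.is_available():
    # What this commit does: generator bound to cuda:0.
    cuda_gen = torch.Generator(device=torch.device("cuda", index=0)).manual_seed(seed)
    cuda_noise = torch.randn(4, generator=cuda_gen, device="cuda")
    # Same seed, different RNG implementations: the values generally differ.
    print(torch.allclose(cpu_noise, cuda_noise.cpu()))  # usually False
```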
 
491
  def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
492
  print(f"precess_svg {model.device} {model.lm.device} {pipeline.device}")
493
  result_images = []
494
+ result_images, svg_file_path = gradio_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
495
  # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
496
 
497
  url, unique_filename = upload_to_github(file_path=svg_file_path)
 
559
  construction_all()
560
  print(f"after construction_all:{model.device} {model.lm.device} {pipeline.device}")
561
562
  def one_click_generate(intention_input, temperature, top_p, seed, true_gs, inference_steps):
563
  # first, call process_preddate
564
  list_box_output, intention_input, list_box_output = process_preddate(intention_input, temperature, top_p)
custom_pipeline.py CHANGED
@@ -1,16 +1,18 @@
1
  import numpy as np
 
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
 
4
- import torch
5
- import torch.nn as nn
6
 
7
- from diffusers.utils.torch_utils import randn_tensor
8
  from diffusers.utils import is_torch_xla_available, logging
9
- from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
10
- from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline
11
 
12
  if is_torch_xla_available():
13
  import torch_xla.core.xla_model as xm # type: ignore
 
14
  XLA_AVAILABLE = True
15
  else:
16
  XLA_AVAILABLE = False
@@ -55,7 +57,6 @@ def _get_clip_prompt_embeds(
55
 
56
  return prompt_embeds
57
 
58
-
59
  def _get_t5_prompt_embeds(
60
  tokenizer,
61
  text_encoder,
@@ -111,6 +112,7 @@ def encode_prompt(
111
  prompt = [prompt] if isinstance(prompt, str) else prompt
112
  prompt_2 = prompt_2 or prompt
113
  prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
114
 
115
  # We only use the pooled prompt output from the CLIPTextModel
116
  pooled_prompt_embeds = _get_clip_prompt_embeds(
@@ -469,73 +471,328 @@ class CustomFluxPipeline(FluxPipeline):
469
  return FluxPipelineOutput(images=image), result_list, vis_list
470
 
471
 
472
- class CustomFluxPipelineCfg(FluxPipeline):
473
 
474
- @staticmethod
475
- def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
476
 
477
- latent_image_ids_list = []
478
- for layer_idx in range(len(list_layer_box)):
479
- if list_layer_box[layer_idx] == None:
480
- continue
481
- else:
482
- latent_image_ids = torch.zeros(height // 2, width // 2, 3) # [h/2, w/2, 3]
483
- latent_image_ids[..., 0] = layer_idx # use the first dimension for layer representation
484
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
485
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
486
 
487
- x1, y1, x2, y2 = list_layer_box[layer_idx]
488
- x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
489
- latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]
490
 
491
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
492
- latent_image_ids = latent_image_ids.reshape(
493
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
494
- )
 
495
 
496
- latent_image_ids_list.append(latent_image_ids)
 
497
 
498
- full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)
 
 
 
 
 
 
 
 
 
 
499
 
500
- return full_latent_image_ids.to(device=device, dtype=dtype)
 
 
501
 
502
- def prepare_latents(
503
- self,
504
- batch_size,
505
- num_layers,
506
- num_channels_latents,
507
- height,
508
- width,
509
- list_layer_box,
510
- dtype,
511
- device,
512
- generator,
513
- latents=None,
514
- ):
515
- height = 2 * (int(height) // self.vae_scale_factor)
516
- width = 2 * (int(width) // self.vae_scale_factor)
517
 
518
- shape = (batch_size, num_layers, num_channels_latents, height, width)
519
 
520
- if latents is not None:
521
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
522
- return latents.to(device=device, dtype=dtype), latent_image_ids
523
 
524
- if isinstance(generator, list) and len(generator) != batch_size:
525
- raise ValueError(
526
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
527
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
528
- )
529
 
530
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # [bs, n_layers, c_latent, h, w]
531
 
532
- latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
533
 
534
- return latents, latent_image_ids
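These two helpers are removed here rather than lost: the rewritten `CustomFluxPipelineCfg` below subclasses `CustomFluxPipeline`, which appears to provide the same `prepare_latents` / `_prepare_latent_image_ids` logic. For readers skimming the diff, a condensed, illustration-only re-implementation of the ID scheme they encode: each layer contributes one `(layer_idx, row, col)` id per packed 2x2 latent patch, cropped to that layer's box at 1/16 of the pixel resolution, and all layers are concatenated into a single token sequence.

```python
import torch

def prepare_layer_image_ids(latent_h, latent_w, layer_boxes, device, dtype):
    ids_per_layer = []
    for layer_idx, box in enumerate(layer_boxes):
        if box is None:
            continue
        # One (layer, row, col) id per 2x2 latent patch of the full canvas.
        ids = torch.zeros(latent_h // 2, latent_w // 2, 3)
        ids[..., 0] = layer_idx
        ids[..., 1] += torch.arange(latent_h // 2)[:, None]
        ids[..., 2] += torch.arange(latent_w // 2)[None, :]
        # Keep only the tokens inside this layer's box (pixel coords // 16 -> id-grid coords).
        x1, y1, x2, y2 = (v // 16 for v in box)
        ids_per_layer.append(ids[y1:y2, x1:x2, :].reshape(-1, 3))
    return torch.cat(ids_per_layer, dim=0).to(device=device, dtype=dtype)

# Example: a 512x512 canvas -> 64x64 latent -> 32x32 id grid per layer.
ids = prepare_layer_image_ids(64, 64, [(0, 0, 512, 512), (128, 128, 384, 384)],
                              device="cpu", dtype=torch.float32)
print(ids.shape)  # torch.Size([32*32 + 16*16, 3]) -> torch.Size([1280, 3])
```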
535
 
536
  @torch.no_grad()
537
  def __call__(
538
  self,
 
 
539
  prompt: Union[str, List[str]] = None,
540
  prompt_2: Optional[Union[str, List[str]]] = None,
541
  validation_box: List[tuple] = None,
@@ -557,6 +814,7 @@ class CustomFluxPipelineCfg(FluxPipeline):
557
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
558
  max_sequence_length: int = 512,
559
  num_layers: int = 5,
 
560
  transparent_decoder: nn.Module = None,
561
  ):
562
  r"""
@@ -703,9 +961,22 @@ class CustomFluxPipelineCfg(FluxPipeline):
703
  latents,
704
  )
705
706
  # 5. Prepare timesteps
707
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
708
- image_seq_len = latent_image_ids.shape[0]
709
  mu = calculate_shift(
710
  image_seq_len,
711
  self.scheduler.config.base_image_seq_len,
@@ -772,6 +1043,16 @@ class CustomFluxPipelineCfg(FluxPipeline):
772
  latents_dtype = latents.dtype
773
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
774
775
  if latents.dtype != latents_dtype:
776
  if torch.backends.mps.is_available():
777
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -794,12 +1075,12 @@ class CustomFluxPipelineCfg(FluxPipeline):
794
  xm.mark_step()
795
 
796
  # create a grey latent
797
- bs, n_layers, channel_latent, height, width = latents.shape
798
 
799
- pixel_grey = torch.zeros(size=(bs*n_layers, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
800
  latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
801
  latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
802
- latent_grey = latent_grey.view(bs, n_layers, channel_latent, height, width) # [bs, n_layers, c_latent, h, w]
803
 
804
  # fill in the latents
805
  for layer_idx in range(latent_grey.shape[1]):
@@ -815,22 +1096,22 @@ class CustomFluxPipelineCfg(FluxPipeline):
815
 
816
  else:
817
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
818
- latents = latents.reshape(bs * n_layers, channel_latent, height, width)
819
  latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
820
  image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
821
  image = torch.cat(image_segs, dim=0)
822
- if transparent_decoder is not None:
823
- transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
824
 
825
- decoded_fg, decoded_alpha = transparent_decoder(latents, [validation_box])
826
- decoded_alpha = (decoded_alpha + 1.0) / 2.0
827
- decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1)
828
 
829
  decoded_fg = (decoded_fg + 1.0) / 2.0
830
- decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)
831
 
832
  vis_list = None
833
- png = torch.cat([decoded_fg, decoded_alpha], dim=3)
834
  result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
835
  else:
836
  result_list, vis_list = None, None
 
1
+ import torch
2
+ import torch.nn as nn
3
  import numpy as np
4
+ import math
5
  from typing import Any, Callable, Dict, List, Optional, Union
6
 
7
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline
8
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
9
 
 
10
  from diffusers.utils import is_torch_xla_available, logging
11
+ from diffusers.utils.torch_utils import randn_tensor
 
12
 
13
  if is_torch_xla_available():
14
  import torch_xla.core.xla_model as xm # type: ignore
15
+
16
  XLA_AVAILABLE = True
17
  else:
18
  XLA_AVAILABLE = False
 
57
 
58
  return prompt_embeds
59
 
 
60
  def _get_t5_prompt_embeds(
61
  tokenizer,
62
  text_encoder,
 
112
  prompt = [prompt] if isinstance(prompt, str) else prompt
113
  prompt_2 = prompt_2 or prompt
114
  prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
115
+ batch_size = len(prompt)
116
 
117
  # We only use the pooled prompt output from the CLIPTextModel
118
  pooled_prompt_embeds = _get_clip_prompt_embeds(
 
471
  return FluxPipelineOutput(images=image), result_list, vis_list
472
 
473
 
474
+ class CustomFluxPipelineCfg(CustomFluxPipeline):
475
 
476
+ @torch.no_grad()
477
+ def __call__(
478
+ self,
479
+ prompt: Union[str, List[str]] = None,
480
+ prompt_2: Optional[Union[str, List[str]]] = None,
481
+ validation_box: List[tuple] = None,
482
+ height: Optional[int] = None,
483
+ width: Optional[int] = None,
484
+ num_inference_steps: int = 28,
485
+ timesteps: List[int] = None,
486
+ guidance_scale: float = 3.5,
487
+ true_gs: float = 3.5,
488
+ num_images_per_prompt: Optional[int] = 1,
489
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
490
+ latents: Optional[torch.FloatTensor] = None,
491
+ prompt_embeds: Optional[torch.FloatTensor] = None,
492
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
493
+ output_type: Optional[str] = "pil",
494
+ return_dict: bool = True,
495
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
496
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
497
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
498
+ max_sequence_length: int = 512,
499
+ num_layers: int = 5,
500
+ sdxl_vae: nn.Module = None,
501
+ transparent_decoder: nn.Module = None,
502
+ ):
503
+ r"""
504
+ Function invoked when calling the pipeline for generation.
505
 
506
+ Args:
507
+ prompt (`str` or `List[str]`, *optional*):
508
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
509
+ instead.
510
+ prompt_2 (`str` or `List[str]`, *optional*):
511
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
512
+ will be used instead
513
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
514
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
515
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
516
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
517
+ num_inference_steps (`int`, *optional*, defaults to 50):
518
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
519
+ expense of slower inference.
520
+ timesteps (`List[int]`, *optional*):
521
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
522
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
523
+ passed will be used. Must be in descending order.
524
+ guidance_scale (`float`, *optional*, defaults to 7.0):
525
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
526
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
527
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
528
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
529
+ usually at the expense of lower image quality.
530
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
531
+ The number of images to generate per prompt.
532
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
533
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
534
+ to make generation deterministic.
535
+ latents (`torch.FloatTensor`, *optional*):
536
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
537
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
538
+ tensor will ge generated by sampling using the supplied random `generator`.
539
+ prompt_embeds (`torch.FloatTensor`, *optional*):
540
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
541
+ provided, text embeddings will be generated from `prompt` input argument.
542
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
543
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
544
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
545
+ output_type (`str`, *optional*, defaults to `"pil"`):
546
+ The output format of the generate image. Choose between
547
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
548
+ return_dict (`bool`, *optional*, defaults to `True`):
549
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
550
+ joint_attention_kwargs (`dict`, *optional*):
551
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
552
+ `self.processor` in
553
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
554
+ callback_on_step_end (`Callable`, *optional*):
555
+ A function that calls at the end of each denoising steps during the inference. The function is called
556
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
557
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
558
+ `callback_on_step_end_tensor_inputs`.
559
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
560
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
561
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
562
+ `._callback_tensor_inputs` attribute of your pipeline class.
563
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
564
 
565
+ Examples:
 
 
566
 
567
+ Returns:
568
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
569
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
570
+ images.
571
+ """
572
 
573
+ height = height or self.default_sample_size * self.vae_scale_factor
574
+ width = width or self.default_sample_size * self.vae_scale_factor
575
 
576
+ # 1. Check inputs. Raise error if not correct
577
+ self.check_inputs(
578
+ prompt,
579
+ prompt_2,
580
+ height,
581
+ width,
582
+ prompt_embeds=prompt_embeds,
583
+ pooled_prompt_embeds=pooled_prompt_embeds,
584
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
585
+ max_sequence_length=max_sequence_length,
586
+ )
587
 
588
+ self._guidance_scale = guidance_scale
589
+ self._joint_attention_kwargs = joint_attention_kwargs
590
+ self._interrupt = False
591
 
592
+ # 2. Define call parameters
593
+ if prompt is not None and isinstance(prompt, str):
594
+ batch_size = 1
595
+ elif prompt is not None and isinstance(prompt, list):
596
+ batch_size = len(prompt)
597
+ else:
598
+ batch_size = prompt_embeds.shape[0]
599
 
600
+ device = self._execution_device
601
 
602
+ lora_scale = (
603
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
604
+ )
605
+ (
606
+ prompt_embeds,
607
+ pooled_prompt_embeds,
608
+ text_ids,
609
+ ) = self.encode_prompt(
610
+ prompt=prompt,
611
+ prompt_2=prompt_2,
612
+ prompt_embeds=prompt_embeds,
613
+ pooled_prompt_embeds=pooled_prompt_embeds,
614
+ device=device,
615
+ num_images_per_prompt=num_images_per_prompt,
616
+ max_sequence_length=max_sequence_length,
617
+ lora_scale=lora_scale,
618
+ )
619
+ (
620
+ neg_prompt_embeds,
621
+ neg_pooled_prompt_embeds,
622
+ neg_text_ids,
623
+ ) = self.encode_prompt(
624
+ prompt="",
625
+ prompt_2=None,
626
+ device=device,
627
+ num_images_per_prompt=num_images_per_prompt,
628
+ max_sequence_length=max_sequence_length,
629
+ lora_scale=lora_scale,
630
+ )
631
 
632
+ # 4. Prepare latent variables
633
+ num_channels_latents = self.transformer.config.in_channels // 4
634
+ latents, latent_image_ids = self.prepare_latents(
635
+ batch_size * num_images_per_prompt,
636
+ num_layers,
637
+ num_channels_latents,
638
+ height,
639
+ width,
640
+ validation_box,
641
+ prompt_embeds.dtype,
642
+ device,
643
+ generator,
644
+ latents,
645
+ )
646
 
647
+ # 5. Prepare timesteps
648
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
649
+ image_seq_len = latent_image_ids.shape[0] # ???
650
+ mu = calculate_shift(
651
+ image_seq_len,
652
+ self.scheduler.config.base_image_seq_len,
653
+ self.scheduler.config.max_image_seq_len,
654
+ self.scheduler.config.base_shift,
655
+ self.scheduler.config.max_shift,
656
+ )
657
+ timesteps, num_inference_steps = retrieve_timesteps(
658
+ self.scheduler,
659
+ num_inference_steps,
660
+ device,
661
+ timesteps,
662
+ sigmas,
663
+ mu=mu,
664
+ )
665
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
666
+ self._num_timesteps = len(timesteps)
667
 
668
+ # handle guidance
669
+ if self.transformer.config.guidance_embeds:
670
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
671
+ guidance = guidance.expand(latents.shape[0])
672
+ else:
673
+ guidance = None
674
 
675
+ # 6. Denoising loop
676
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
677
+ for i, t in enumerate(timesteps):
678
+ if self.interrupt:
679
+ continue
680
+
681
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
682
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
683
+
684
+ noise_pred = self.transformer(
685
+ hidden_states=latents,
686
+ list_layer_box=validation_box,
687
+ timestep=timestep / 1000,
688
+ guidance=guidance,
689
+ pooled_projections=pooled_prompt_embeds,
690
+ encoder_hidden_states=prompt_embeds,
691
+ txt_ids=text_ids,
692
+ img_ids=latent_image_ids,
693
+ joint_attention_kwargs=self.joint_attention_kwargs,
694
+ return_dict=False,
695
+ )[0]
696
+
697
+ neg_noise_pred = self.transformer(
698
+ hidden_states=latents,
699
+ list_layer_box=validation_box,
700
+ timestep=timestep / 1000,
701
+ guidance=guidance,
702
+ pooled_projections=neg_pooled_prompt_embeds,
703
+ encoder_hidden_states=neg_prompt_embeds,
704
+ txt_ids=neg_text_ids,
705
+ img_ids=latent_image_ids,
706
+ joint_attention_kwargs=self.joint_attention_kwargs,
707
+ return_dict=False,
708
+ )[0]
709
+
710
+ noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)
711
+
712
+ # compute the previous noisy sample x_t -> x_t-1
713
+ latents_dtype = latents.dtype
714
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
715
+
716
+ if latents.dtype != latents_dtype:
717
+ if torch.backends.mps.is_available():
718
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
719
+ latents = latents.to(latents_dtype)
720
+
721
+ if callback_on_step_end is not None:
722
+ callback_kwargs = {}
723
+ for k in callback_on_step_end_tensor_inputs:
724
+ callback_kwargs[k] = locals()[k]
725
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
726
+
727
+ latents = callback_outputs.pop("latents", latents)
728
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
729
+
730
+ # call the callback, if provided
731
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
732
+ progress_bar.update()
733
+
734
+ if XLA_AVAILABLE:
735
+ xm.mark_step()
736
+
737
+ # create a grey latent
738
+ bs, n_frames, channel_latent, height, width = latents.shape
739
+
740
+ pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
741
+ latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
742
+ latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
743
+ latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w]
744
+
745
+ # fill in the latents
746
+ for layer_idx in range(latent_grey.shape[1]):
747
+ if validation_box[layer_idx] == None:
748
+ continue
749
+ x1, y1, x2, y2 = validation_box[layer_idx]
750
+ x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
751
+ latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
752
+ latents = latent_grey
753
+
754
+ if output_type == "latent":
755
+ image = latents
756
+
757
+ else:
758
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
759
+ latents = latents.reshape(bs * n_frames, channel_latent, height, width)
760
+ latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
761
+ image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
762
+ image = torch.cat(image_segs, dim=0)
763
+ if sdxl_vae is not None:
764
+ sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
765
+
766
+ decoded_fg, decoded_alpha = sdxl_vae(latents, [validation_box])
767
+ decoded_alpha = (decoded_alpha + 1.0) / 2.0 #torch.Size([5, 1, 1024, 1024])
768
+ decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1) #torch.Size([5, 1024, 1024, 1])
769
+
770
+ decoded_fg = (decoded_fg + 1.0) / 2.0
771
+ decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)#torch.Size([5, 1024, 1024, 3]))
772
+
773
+ vis_list = None
774
+ png = torch.cat([decoded_fg, decoded_alpha], dim=3)#[0] #torch.Size([1024, 1024, 4])
775
+ result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
776
+ else:
777
+ result_list, vis_list = None, None
778
+ image = self.image_processor.postprocess(image, output_type=output_type)
779
+
780
+ # Offload all models
781
+ self.maybe_free_model_hooks()
782
+
783
+ if not return_dict:
784
+ return (image, result_list, vis_list, latents)
785
+
786
+ return FluxPipelineOutput(images=image), result_list, vis_list, latents
787
+
788
+
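Besides the distilled guidance embedding (`guidance_scale` fed into the transformer), the new `__call__` above runs a second transformer pass with an empty-prompt embedding at every step and mixes the two predictions with `true_gs`, i.e. classic classifier-free guidance. A minimal sketch of that combination step, with toy tensors in place of the two transformer outputs:

```python
import torch

def true_cfg(cond_pred, uncond_pred, true_gs):
    # Move the unconditional prediction towards the conditional one by a factor of true_gs.
    return uncond_pred + true_gs * (cond_pred - uncond_pred)

cond = torch.randn(5, 16, 64, 64)     # toy "noise prediction" for a [layers, c, h, w] latent
uncond = torch.randn(5, 16, 64, 64)
guided = true_cfg(cond, uncond, true_gs=3.5)

assert guided.shape == cond.shape
assert torch.allclose(true_cfg(cond, uncond, 1.0), cond)  # true_gs == 1.0 -> conditional only
```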
789
+ class CustomFluxPipelineCfgInpaint(CustomFluxPipeline):
790
 
791
  @torch.no_grad()
792
  def __call__(
793
  self,
794
+ image: Optional[List[torch.FloatTensor]] = None,
795
+ mask: Optional[torch.FloatTensor] = None,
796
  prompt: Union[str, List[str]] = None,
797
  prompt_2: Optional[Union[str, List[str]]] = None,
798
  validation_box: List[tuple] = None,
 
814
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
815
  max_sequence_length: int = 512,
816
  num_layers: int = 5,
817
+ sdxl_vae: nn.Module = None,
818
  transparent_decoder: nn.Module = None,
819
  ):
820
  r"""
 
961
  latents,
962
  )
963
 
964
+ # 4.1. Prepare image and mask
965
+ merged_pt, backgd_pt, list_layer_pt = image[0], image[1], image[2:]
966
+ # prepare RGB, Alpha
967
+ layer_pt_grey = [layer_pt[:, :3] * ((layer_pt[:, 3:4] + 1) / 2.) for layer_pt in list_layer_pt]
968
+ pixel_values_vae_input = torch.cat([merged_pt, backgd_pt] + layer_pt_grey, dim=0).to(device, dtype=self.vae.dtype) # [bs*(l+2), c_img, H, W]
969
+ # Convert images to latent space
970
+ model_input = self.vae.encode(pixel_values_vae_input).latent_dist.sample()
971
+ model_input = (model_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
972
+ model_input = model_input.reshape(1, len(validation_box), model_input.shape[1], model_input.shape[2], model_input.shape[3]) # [bs, f, c_latent, h, w]
973
+ # copy latent and noise
974
+ orig_latents = model_input
975
+ noise = latents.clone()
976
+
977
  # 5. Prepare timesteps
978
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
979
+ image_seq_len = latent_image_ids.shape[0] # ???
980
  mu = calculate_shift(
981
  image_seq_len,
982
  self.scheduler.config.base_image_seq_len,
 
1043
  latents_dtype = latents.dtype
1044
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1045
 
1046
+ # blend the latents with the original image
1047
+ init_latents_proper = orig_latents.to(latents.dtype)
1048
+ init_mask = mask.reshape(1, -1, 1, 1, 1).to(latents.dtype)
1049
+ if i < len(timesteps) - 1:
1050
+ noise_timestep = timesteps[i + 1]
1051
+ init_latents_proper = self.scheduler.scale_noise(
1052
+ init_latents_proper, torch.tensor([noise_timestep]), noise
1053
+ )
1054
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1055
+
1056
  if latents.dtype != latents_dtype:
1057
  if torch.backends.mps.is_available():
1058
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
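The block added above keeps everything outside the inpainting mask pinned to the (re-noised) reference latents at every denoising step, so only the masked layers are actually resampled. A schematic, stand-alone version of that blend; `scheduler.scale_noise` is assumed to behave like diffusers' `FlowMatchEulerDiscreteScheduler.scale_noise`:

```python
import torch

def blend_step(latents, orig_latents, noise, mask, scheduler, timesteps, i):
    # Shapes follow the pipeline: [bs, layers, c_latent, h, w]; mask is 1 = regenerate, 0 = keep.
    init_latents = orig_latents.to(latents.dtype)
    init_mask = mask.reshape(1, -1, 1, 1, 1).to(latents.dtype)
    if i < len(timesteps) - 1:
        # Re-noise the clean reference latents to the next timestep so both terms
        # sit at the same noise level before mixing.
        next_t = timesteps[i + 1]
        init_latents = scheduler.scale_noise(init_latents, torch.tensor([next_t]), noise)
    return (1 - init_mask) * init_latents + init_mask * latents
```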
 
1075
  xm.mark_step()
1076
 
1077
  # create a grey latent
1078
+ bs, n_frames, channel_latent, height, width = latents.shape
1079
 
1080
+ pixel_grey = torch.zeros(size=(bs*n_frames, 3, height*8, width*8), device=latents.device, dtype=latents.dtype)
1081
  latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
1082
  latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
1083
+ latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width) # [bs, f, c_latent, h, w]
1084
 
1085
  # fill in the latents
1086
  for layer_idx in range(latent_grey.shape[1]):
 
1096
 
1097
  else:
1098
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1099
+ latents = latents.reshape(bs * n_frames, channel_latent, height, width)
1100
  latents_segs = torch.split(latents, 16, dim=0) ### split latents by 16 to avoid odd purple output
1101
  image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
1102
  image = torch.cat(image_segs, dim=0)
1103
+ if sdxl_vae is not None:
1104
+ sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
1105
 
1106
+ decoded_fg, decoded_alpha = sdxl_vae(latents, [validation_box])
1107
+ decoded_alpha = (decoded_alpha + 1.0) / 2.0 #torch.Size([5, 1, 1024, 1024])
1108
+ decoded_alpha = torch.clamp(decoded_alpha, min=0.0, max=1.0).permute(0, 2, 3, 1) #torch.Size([5, 1024, 1024, 1])
1109
 
1110
  decoded_fg = (decoded_fg + 1.0) / 2.0
1111
+ decoded_fg = torch.clamp(decoded_fg, min=0.0, max=1.0).permute(0, 2, 3, 1)#torch.Size([5, 1024, 1024, 3]))
1112
 
1113
  vis_list = None
1114
+ png = torch.cat([decoded_fg, decoded_alpha], dim=3)#[0] #torch.Size([1024, 1024, 4])
1115
  result_list = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
1116
  else:
1117
  result_list, vis_list = None, None
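Both pipelines decode the stacked per-layer latents in slices of 16 ("split latents by 16 to avoid odd purple output"), which also bounds peak VAE memory. A generic sketch of that chunked decode; `vae` is any Flux-style `AutoencoderKL` with `scaling_factor`/`shift_factor` in its config, and the chunk size of 16 is simply carried over from the diff:

```python
import torch

@torch.no_grad()
def decode_in_chunks(vae, latents, chunk_size=16):
    # latents: [N, c_latent, h, w] stacked per-layer latents.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    images = []
    for chunk in torch.split(latents, chunk_size, dim=0):
        images.append(vae.decode(chunk, return_dict=False)[0])
    return torch.cat(images, dim=0)  # [N, 3, H, W] in the VAE's output range
```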
requirements.txt CHANGED
@@ -43,4 +43,5 @@ pynvml==11.5.3 # explicit pin added (conda actually installs 11.5.3)
 colorama==0.4.6 # explicit pin added (conda actually installs 0.4.6)
 click>=8.0.4,<9 # keep constraint (conda actually installs 8.1.7, which satisfies it)\
 
-sentencepiece
+sentencepiece
+random  # note: `random` ships with the Python standard library, so this entry is redundant for pip