aducsdr commited on
Commit
264d081
·
verified ·
1 Parent(s): f1a1847

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -23
app.py CHANGED
@@ -7,17 +7,12 @@ from PIL import Image
7
  from omegaconf import OmegaConf
8
  from huggingface_hub import hf_hub_download
9
 
10
- # --- Início: Bloco de Download Automático do Modelo ---
11
-
12
- # Define o diretório e o caminho para os pesos do modelo
13
  WEIGHTS_DIR = "./pretrained_weights/ByteMorpher"
14
  MODEL_FILENAME = "dit.safetensors"
15
  MODEL_PATH = os.path.join(WEIGHTS_DIR, MODEL_FILENAME)
16
-
17
- # Cria o diretório se ele não existir
18
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
19
 
20
- # Verifica se o modelo já existe antes de fazer o download
21
  if not os.path.exists(MODEL_PATH):
22
  print(f"Modelo não encontrado em {MODEL_PATH}. Baixando do Hugging Face Hub...")
23
  try:
@@ -25,17 +20,14 @@ if not os.path.exists(MODEL_PATH):
25
  repo_id="ByteDance-Seed/BM-Model",
26
  filename=MODEL_FILENAME,
27
  local_dir=WEIGHTS_DIR,
28
- local_dir_use_symlinks=False # Recomendado para Hugging Face Spaces
29
  )
30
  print("Download do modelo concluído com sucesso.")
31
  except Exception as e:
32
  print(f"Ocorreu um erro durante o download do modelo: {e}")
33
- # Se o download falhar, o aplicativo não poderá funcionar.
34
- # Você pode adicionar um tratamento de erro mais robusto aqui se desejar.
35
  else:
36
  print(f"Modelo já existe em {MODEL_PATH}. Pulando o download.")
37
-
38
- # --- Fim: Bloco de Download Automático do Modelo ---
39
 
40
 
41
  from image_datasets.dataset import image_resize
@@ -50,16 +42,15 @@ def generate(image: Image.Image, edit_prompt: str):
50
  from src.flux.xflux_pipeline import XFluxSampler
51
 
52
  global sampler
53
- if sampler == None:
54
- # A inicialização do sampler agora ocorrerá após a confirmação de que o modelo foi baixado.
 
55
  sampler = XFluxSampler(
56
- device = device,
57
- ip_loaded=False,
58
- spatial_condition=False,
59
- clip_image_processor=None,
60
- image_encoder=None,
61
- improj=None,
62
- share_position_embedding = True,
63
  )
64
 
65
  img = image_resize(image, 544)
@@ -68,6 +59,9 @@ def generate(image: Image.Image, edit_prompt: str):
68
  img = torch.from_numpy((np.array(img) / 127.5) - 1)
69
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
70
 
 
 
 
71
  result = sampler(
72
  prompt=edit_prompt,
73
  width=args.sample_width,
@@ -75,9 +69,9 @@ def generate(image: Image.Image, edit_prompt: str):
75
  num_steps=args.sample_steps,
76
  image_prompt=None,
77
  true_gs=args.cfg_scale,
78
- seed=args.seed,
79
  ip_scale=args.ip_scale if args.use_ip else 1.0,
80
- source_image=img if args.use_spatial_condition else None,
81
  )
82
  return result
83
 
@@ -201,7 +195,6 @@ def create_app():
201
  </div>
202
  """
203
  )
204
- # gr.Markdown(header, elem_id="header")
205
  with gr.Row(equal_height=False):
206
  with gr.Column(variant="panel", elem_classes="inputPanel"):
207
  original_image = gr.Image(
 
7
  from omegaconf import OmegaConf
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # --- Bloco de Download Automático do Modelo ---
 
 
11
  WEIGHTS_DIR = "./pretrained_weights/ByteMorpher"
12
  MODEL_FILENAME = "dit.safetensors"
13
  MODEL_PATH = os.path.join(WEIGHTS_DIR, MODEL_FILENAME)
 
 
14
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
15
 
 
16
  if not os.path.exists(MODEL_PATH):
17
  print(f"Modelo não encontrado em {MODEL_PATH}. Baixando do Hugging Face Hub...")
18
  try:
 
20
  repo_id="ByteDance-Seed/BM-Model",
21
  filename=MODEL_FILENAME,
22
  local_dir=WEIGHTS_DIR,
23
+ local_dir_use_symlinks=False
24
  )
25
  print("Download do modelo concluído com sucesso.")
26
  except Exception as e:
27
  print(f"Ocorreu um erro durante o download do modelo: {e}")
 
 
28
  else:
29
  print(f"Modelo já existe em {MODEL_PATH}. Pulando o download.")
30
+ # --- Fim do Bloco de Download ---
 
31
 
32
 
33
  from image_datasets.dataset import image_resize
 
42
  from src.flux.xflux_pipeline import XFluxSampler
43
 
44
  global sampler
45
+ if sampler is None:
46
+ # CORREÇÃO: Inicializa o sampler usando os argumentos do arquivo .yaml
47
+ print("Inicializando o XFluxSampler com a configuração...")
48
  sampler = XFluxSampler(
49
+ device=device,
50
+ ip_loaded=args.use_ip,
51
+ spatial_condition=args.use_spatial_condition,
52
+ share_position_embedding=args.share_position_embedding,
53
+ use_share_weight_referencenet=args.use_share_weight_referencenet
 
 
54
  )
55
 
56
  img = image_resize(image, 544)
 
59
  img = torch.from_numpy((np.array(img) / 127.5) - 1)
60
  img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
61
 
62
+ # CORREÇÃO: Passa a imagem de origem se qualquer modo de condicionamento estiver ativo
63
+ use_image_conditioning = args.use_spatial_condition or args.use_share_weight_referencenet
64
+
65
  result = sampler(
66
  prompt=edit_prompt,
67
  width=args.sample_width,
 
69
  num_steps=args.sample_steps,
70
  image_prompt=None,
71
  true_gs=args.cfg_scale,
72
+ seed=args.seed if args.seed != -1 else np.random.randint(0, 2**32 - 1),
73
  ip_scale=args.ip_scale if args.use_ip else 1.0,
74
+ source_image=img if use_image_conditioning else None,
75
  )
76
  return result
77
 
 
195
  </div>
196
  """
197
  )
 
198
  with gr.Row(equal_height=False):
199
  with gr.Column(variant="panel", elem_classes="inputPanel"):
200
  original_image = gr.Image(