fffiloni committed
Commit 40d9091
1 Parent(s): 1f2d931

Update app.py

Files changed (1):
  1. app.py +24 -10
app.py CHANGED
@@ -182,7 +182,9 @@ def infer(ref_style_file, style_description, caption, progress):
         lam_style=1, lam_txt_alignment=1.0,
         use_ddim_sampler=True,
     )
-    for (sampled_c, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), 1):
+    for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
+        if i % 5 == 0:  # Update progress every 5 steps
+            progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
         sampled_c = sampled_c
 
     progress(0.7, "Starting Stage B reverse process")
@@ -195,7 +197,9 @@ def infer(ref_style_file, style_description, caption, progress):
         models_b.generator, conditions_b, stage_b_latent_shape,
         unconditions_b, device=device, **extras_b.sampling_configs,
     )
-    for (sampled_b, _, _) in enumerate(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), 1):
+    for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
+        if i % 5 == 0:  # Update progress every 5 steps
+            progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
         sampled_b = sampled_b
     sampled = models_b.stage_a.decode(sampled_b).float()
 
@@ -222,7 +226,7 @@ def infer(ref_style_file, style_description, caption, progress):
     # Clear CUDA cache
     torch.cuda.empty_cache()
 
-def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
+def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
     global models_rbm, models_b, device, sam_model
     if low_vram:
         models_to(models_rbm, device=device)
@@ -246,13 +250,15 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
     extras_b.sampling_configs['timesteps'] = 10
     extras_b.sampling_configs['t_start'] = 1.0
 
+    progress(0.1, "Loading style and subject reference images")
     ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
     ref_images = resize_image(PIL.Image.open(ref_sub_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
     batch = {'captions': [caption] * batch_size}
     batch['style'] = ref_style
     batch['images'] = ref_images
-
+
+    progress(0.2, "Processing reference images")
     x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images.to(device)))
     x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
@@ -264,7 +270,8 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
     sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
     # sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
     sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
-
+
+    progress(0.3, "Generating conditions")
     conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_subject_style=True, eval_csd=False)
     unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False, eval_subject_style=True)
     conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
@@ -274,7 +281,8 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
         models_to(sam_model, device="cpu")
         models_to(sam_model.sam, device="cpu")
-
+
+    progress(0.4, "Starting Stage C reverse process")
     # Stage C reverse process.
     sampling_c = extras.gdf.sample(
         models_rbm.generator, conditions, stage_c_latent_shape,
@@ -291,9 +299,12 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
         sam_prompt=sam_prompt
     )
 
-    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+    for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
+        if i % 5 == 0:  # Update progress every 5 steps
+            progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
         sampled_c = sampled_c
 
+    progress(0.7, "Starting Stage B reverse process")
     # Stage B reverse process.
     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         conditions_b['effnet'] = sampled_c
@@ -303,10 +314,13 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
             models_b.generator, conditions_b, stage_b_latent_shape,
             unconditions_b, device=device, **extras_b.sampling_configs,
         )
-        for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+        for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
+            if i % 5 == 0:  # Update progress every 5 steps
+                progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
             sampled_b = sampled_b
         sampled = models_b.stage_a.decode(sampled_b).float()
 
+        progress(0.9, "Finalizing the output image")
         sampled = torch.cat([
             torch.nn.functional.interpolate(ref_images.cpu(), size=(height, width)),
             torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
@@ -332,9 +346,9 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
 
 def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
     result = None
-    progress = gr.Progress(track_tqdm=True)
+    progress = gr.Progress()
     if use_subject_ref is True:
-        result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference)
+        result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, progress)
     else:
         result = infer(style_reference_image, style_description, subject_prompt, progress)
     return result
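
For reference, the reporting pattern this diff moves to can be tried outside of app.py. The sketch below is a minimal, self-contained example under stated assumptions: fake_sampler, the step count, and the run handler are placeholders rather than code from this Space. It only illustrates a gr.Progress tracker being supplied to a Gradio handler and updated manually with a completion fraction and a short status string every few steps, instead of relying on track_tqdm=True.

import time

import gradio as gr


def fake_sampler(steps):
    # Placeholder standing in for the Stage C / Stage B reverse-process generators.
    for i in range(steps):
        time.sleep(0.05)
        yield i


def run(steps, progress=gr.Progress()):
    # Gradio injects a Progress tracker when the handler declares a gr.Progress()
    # default; calling it with a value in [0, 1] plus a description updates the bar.
    total = int(steps)
    for i, _ in enumerate(fake_sampler(total), 1):
        if i % 5 == 0:  # throttle updates, mirroring the every-5-steps pattern above
            progress(i / total, desc=f"step {i}/{total}")
    return f"finished {total} steps"


demo = gr.Interface(fn=run, inputs=gr.Number(value=50, label="steps"), outputs="text")

if __name__ == "__main__":
    demo.launch()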