RageshAntony commited on
Commit
99d1063
·
verified ·
1 Parent(s): 8f2de47

fixed callback

Browse files
Files changed (1) hide show
  1. check_app.py +6 -6
check_app.py CHANGED
@@ -15,17 +15,17 @@ def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, s
15
  generator = torch.Generator("cpu").manual_seed(seed)
16
  print("Start generating")
17
  # Wrapper to track progress
18
- def callback(step, timestep, latents):
19
- cur_prg = step / num_steps
20
- print(f"Progressing {cur_prg} ")
21
- progress(cur_prg, desc=f"Step {step}/{num_steps}")
22
 
23
  if isinstance(pipe, StableDiffusion3Pipeline):
24
  image = pipe(
25
  prompt,
26
  num_inference_steps=num_steps,
27
  guidance_scale=guidance_scale,
28
- callback=callback,
29
  ).images[0]
30
  elif isinstance(pipe, FluxPipeline):
31
  image = pipe(
@@ -33,7 +33,7 @@ def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, s
33
  num_inference_steps=num_steps,
34
  generator=generator,
35
  output_type="pil",
36
- callback=callback,
37
  ).images[0]
38
  return image
39
 
 
15
  generator = torch.Generator("cpu").manual_seed(seed)
16
  print("Start generating")
17
  # Wrapper to track progress
18
+ def callback(pipe, step_index, timestep, callback_kwargs): # pipe, step_index, timestep, callback_kwargs
19
+ cur_prg = step_index / num_steps
20
+ print(f"Progressing {cur_prg} Step {step_index}/{num_steps}")
21
+ progress(cur_prg, desc=f"Step {step_index}/{num_steps}")
22
 
23
  if isinstance(pipe, StableDiffusion3Pipeline):
24
  image = pipe(
25
  prompt,
26
  num_inference_steps=num_steps,
27
  guidance_scale=guidance_scale,
28
+ callback_on_step_end=callback,
29
  ).images[0]
30
  elif isinstance(pipe, FluxPipeline):
31
  image = pipe(
 
33
  num_inference_steps=num_steps,
34
  generator=generator,
35
  output_type="pil",
36
+ callback_on_step_end=callback,
37
  ).images[0]
38
  return image
39