RageshAntony commited on
Commit
9dab390
·
verified ·
1 Parent(s): 40e3650

added ProgressAuraFlowPipeline

Browse files
Files changed (1) hide show
  1. check_app.py +54 -0
check_app.py CHANGED
@@ -12,6 +12,60 @@ from diffusers import (
12
  LuminaText2ImgPipeline,AutoPipelineForText2Image
13
  )
14
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  cache_dir = '/workspace/hf_cache'
17
 
 
12
  LuminaText2ImgPipeline,AutoPipelineForText2Image
13
  )
14
  import gradio as gr
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+
17
+ class ProgressAuraFlowPipeline(DiffusionPipeline):
18
+ def __init__(self, original_pipeline):
19
+ super().__init__()
20
+ self.original_pipeline = original_pipeline
21
+ # Register all components from the original pipeline
22
+ for attr_name, attr_value in vars(original_pipeline).items():
23
+ setattr(self, attr_name, attr_value)
24
+
25
+ @torch.no_grad()
26
+ def __call__(
27
+ self,
28
+ prompt,
29
+ num_inference_steps=30,
30
+ generator=None,
31
+ guidance_scale=7.5,
32
+ callback=None,
33
+ callback_steps=1,
34
+ **kwargs
35
+ ):
36
+ # Initialize the progress tracking
37
+ self._num_inference_steps = num_inference_steps
38
+ self._step = 0
39
+
40
+ def progress_callback(pipe, step_index, timestep, callback_kwargs):
41
+ if callback and step_index % callback_steps == 0:
42
+ callback(step_index, timestep, callback_kwargs)
43
+ return callback_kwargs
44
+
45
+ # Monkey patch the original pipeline's progress tracking
46
+ original_step = self.original_pipeline.scheduler.step
47
+ def wrapped_step(*args, **kwargs):
48
+ self._step += 1
49
+ if callback:
50
+ progress_callback(self, self._step, None, {})
51
+ return original_step(*args, **kwargs)
52
+
53
+ self.original_pipeline.scheduler.step = wrapped_step
54
+
55
+ try:
56
+ # Call the original pipeline
57
+ result = self.original_pipeline(
58
+ prompt=prompt,
59
+ num_inference_steps=num_inference_steps,
60
+ generator=generator,
61
+ guidance_scale=guidance_scale,
62
+ **kwargs
63
+ )
64
+
65
+ return result
66
+ finally:
67
+ # Restore the original step function
68
+ self.original_pipeline.scheduler.step = original_step
69
 
70
  cache_dir = '/workspace/hf_cache'
71