Update pipeline_bria_controlnet.py
Browse files
pipeline_bria_controlnet.py
CHANGED
@@ -25,9 +25,9 @@ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
|
25 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
26 |
from diffusers.utils import logging
|
27 |
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
28 |
-
from
|
29 |
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
|
30 |
-
from
|
31 |
from transformer_bria import BriaTransformer2DModel
|
32 |
from bria_utils import get_original_sigmas
|
33 |
import numpy as np
|
@@ -397,7 +397,10 @@ class BriaControlNetPipeline(BriaPipeline):
|
|
397 |
|
398 |
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
399 |
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
400 |
-
|
|
|
|
|
|
|
401 |
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
402 |
|
403 |
mu = calculate_shift(
|
|
|
25 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
26 |
from diffusers.utils import logging
|
27 |
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
28 |
+
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
|
29 |
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
|
30 |
+
from pipeline_bria import BriaPipeline
|
31 |
from transformer_bria import BriaTransformer2DModel
|
32 |
from bria_utils import get_original_sigmas
|
33 |
import numpy as np
|
|
|
397 |
|
398 |
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
399 |
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
400 |
+
if type(control_image) == list:
|
401 |
+
image_seq_len = control_image[0].shape[1]
|
402 |
+
else:
|
403 |
+
image_seq_len = control_image.shape[1]
|
404 |
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
405 |
|
406 |
mu = calculate_shift(
|