Diffusers
TalHach61 commited on
Commit
81941e2
·
verified ·
1 Parent(s): f643374

Update pipeline_bria_controlnet.py

Browse files
Files changed (1) hide show
  1. pipeline_bria_controlnet.py +6 -3
pipeline_bria_controlnet.py CHANGED
@@ -25,9 +25,9 @@ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
25
  from diffusers.schedulers import KarrasDiffusionSchedulers
26
  from diffusers.utils import logging
27
  from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
28
- from .controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
29
  from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
30
- from .pipeline_bria import BriaPipeline
31
  from transformer_bria import BriaTransformer2DModel
32
  from bria_utils import get_original_sigmas
33
  import numpy as np
@@ -397,7 +397,10 @@ class BriaControlNetPipeline(BriaPipeline):
397
 
398
  if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
399
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
400
- image_seq_len = control_image.shape[1]
 
 
 
401
  print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
402
 
403
  mu = calculate_shift(
 
25
  from diffusers.schedulers import KarrasDiffusionSchedulers
26
  from diffusers.utils import logging
27
  from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
28
+ from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
29
  from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
30
+ from pipeline_bria import BriaPipeline
31
  from transformer_bria import BriaTransformer2DModel
32
  from bria_utils import get_original_sigmas
33
  import numpy as np
 
397
 
398
  if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
399
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
400
+ if type(control_image) == list:
401
+ image_seq_len = control_image[0].shape[1]
402
+ else:
403
+ image_seq_len = control_image.shape[1]
404
  print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
405
 
406
  mu = calculate_shift(