|  | import inspect | 
					
						
						|  | from typing import Any, Dict, List, Optional, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor | 
					
						
						|  |  | 
					
						
						|  | from diffusers import DiffusionPipeline | 
					
						
						|  | from diffusers.image_processor import VaeImageProcessor | 
					
						
						|  | from diffusers.loaders import LoraLoaderMixin | 
					
						
						|  | from diffusers.models import AutoencoderKL, UNet2DConditionModel | 
					
						
						|  | from diffusers.models.lora import adjust_lora_scale_text_encoder | 
					
						
						|  | from diffusers.pipelines.pipeline_utils import StableDiffusionMixin | 
					
						
						|  | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput | 
					
						
						|  | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | 
					
						
						|  | from diffusers.schedulers import KarrasDiffusionSchedulers | 
					
						
						|  | from diffusers.utils import ( | 
					
						
						|  | USE_PEFT_BACKEND, | 
					
						
						|  | logging, | 
					
						
						|  | scale_lora_layers, | 
					
						
						|  | unscale_lora_layers, | 
					
						
						|  | ) | 
					
						
						|  | from diffusers.utils.torch_utils import randn_tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module-level logger from diffusers' logging utilities; used below to warn when a prompt is truncated.
logger = logging.get_logger(__name__)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TranslatorBase(nn.Module): | 
					
						
						|  | def __init__(self, num_tok, dim, dim_out, mult=2): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.dim_in = dim | 
					
						
						|  | self.dim_out = dim_out | 
					
						
						|  |  | 
					
						
						|  | self.net_tok = nn.Sequential( | 
					
						
						|  | nn.Linear(num_tok, int(num_tok * mult)), | 
					
						
						|  | nn.LayerNorm(int(num_tok * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(num_tok * mult), int(num_tok * mult)), | 
					
						
						|  | nn.LayerNorm(int(num_tok * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(num_tok * mult), num_tok), | 
					
						
						|  | nn.LayerNorm(num_tok), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.net_sen = nn.Sequential( | 
					
						
						|  | nn.Linear(dim, int(dim * mult)), | 
					
						
						|  | nn.LayerNorm(int(dim * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(dim * mult), int(dim * mult)), | 
					
						
						|  | nn.LayerNorm(int(dim * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(dim * mult), dim_out), | 
					
						
						|  | nn.LayerNorm(dim_out), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | if self.dim_in == self.dim_out: | 
					
						
						|  | indentity_0 = x | 
					
						
						|  | x = self.net_sen(x) | 
					
						
						|  | x += indentity_0 | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | indentity_1 = x | 
					
						
						|  | x = self.net_tok(x) | 
					
						
						|  | x += indentity_1 | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  | else: | 
					
						
						|  | x = self.net_sen(x) | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | x = self.net_tok(x) | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TranslatorBaseNoLN(nn.Module): | 
					
						
						|  | def __init__(self, num_tok, dim, dim_out, mult=2): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.dim_in = dim | 
					
						
						|  | self.dim_out = dim_out | 
					
						
						|  |  | 
					
						
						|  | self.net_tok = nn.Sequential( | 
					
						
						|  | nn.Linear(num_tok, int(num_tok * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(num_tok * mult), int(num_tok * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(num_tok * mult), num_tok), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.net_sen = nn.Sequential( | 
					
						
						|  | nn.Linear(dim, int(dim * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(dim * mult), int(dim * mult)), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(int(dim * mult), dim_out), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | if self.dim_in == self.dim_out: | 
					
						
						|  | indentity_0 = x | 
					
						
						|  | x = self.net_sen(x) | 
					
						
						|  | x += indentity_0 | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | indentity_1 = x | 
					
						
						|  | x = self.net_tok(x) | 
					
						
						|  | x += indentity_1 | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  | else: | 
					
						
						|  | x = self.net_sen(x) | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | x = self.net_tok(x) | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TranslatorNoLN(nn.Module): | 
					
						
						|  | def __init__(self, num_tok, dim, dim_out, mult=2, depth=5): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.blocks = nn.ModuleList([TranslatorBase(num_tok, dim, dim, mult=2) for d in range(depth)]) | 
					
						
						|  | self.gelu = nn.GELU() | 
					
						
						|  |  | 
					
						
						|  | self.tail = TranslatorBaseNoLN(num_tok, dim, dim_out, mult=2) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | for block in self.blocks: | 
					
						
						|  | x = block(x) + x | 
					
						
						|  | x = self.gelu(x) | 
					
						
						|  |  | 
					
						
						|  | x = self.tail(x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | 
					
						
						|  | """ | 
					
						
						|  | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and | 
					
						
						|  | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 | 
					
						
						|  | """ | 
					
						
						|  | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) | 
					
						
						|  | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) | 
					
						
						|  |  | 
					
						
						|  | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) | 
					
						
						|  |  | 
					
						
						|  | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg | 
					
						
						|  | return noise_cfg | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def retrieve_timesteps( | 
					
						
						|  | scheduler, | 
					
						
						|  | num_inference_steps: Optional[int] = None, | 
					
						
						|  | device: Optional[Union[str, torch.device]] = None, | 
					
						
						|  | timesteps: Optional[List[int]] = None, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles | 
					
						
						|  | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | scheduler (`SchedulerMixin`): | 
					
						
						|  | The scheduler to get timesteps from. | 
					
						
						|  | num_inference_steps (`int`): | 
					
						
						|  | The number of diffusion steps used when generating samples with a pre-trained model. If used, | 
					
						
						|  | `timesteps` must be `None`. | 
					
						
						|  | device (`str` or `torch.device`, *optional*): | 
					
						
						|  | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. | 
					
						
						|  | timesteps (`List[int]`, *optional*): | 
					
						
						|  | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default | 
					
						
						|  | timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` | 
					
						
						|  | must be `None`. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the | 
					
						
						|  | second element is the number of inference steps. | 
					
						
						|  | """ | 
					
						
						|  | if timesteps is not None: | 
					
						
						|  | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) | 
					
						
						|  | if not accepts_timesteps: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" | 
					
						
						|  | f" timestep schedules. Please check whether you are using the correct scheduler." | 
					
						
						|  | ) | 
					
						
						|  | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) | 
					
						
						|  | timesteps = scheduler.timesteps | 
					
						
						|  | num_inference_steps = len(timesteps) | 
					
						
						|  | else: | 
					
						
						|  | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) | 
					
						
						|  | timesteps = scheduler.timesteps | 
					
						
						|  | return timesteps, num_inference_steps | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, LoraLoaderMixin): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | vae: AutoencoderKL, | 
					
						
						|  | text_encoder: AutoModel, | 
					
						
						|  | tokenizer: AutoTokenizer, | 
					
						
						|  | unet: UNet2DConditionModel, | 
					
						
						|  | scheduler: KarrasDiffusionSchedulers, | 
					
						
						|  | safety_checker: StableDiffusionSafetyChecker, | 
					
						
						|  | feature_extractor: CLIPImageProcessor, | 
					
						
						|  | language_adapter: TranslatorNoLN = None, | 
					
						
						|  | tensor_norm: torch.Tensor = None, | 
					
						
						|  | requires_safety_checker: bool = True, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.register_modules( | 
					
						
						|  | vae=vae, | 
					
						
						|  | text_encoder=text_encoder, | 
					
						
						|  | tokenizer=tokenizer, | 
					
						
						|  | unet=unet, | 
					
						
						|  | scheduler=scheduler, | 
					
						
						|  | safety_checker=safety_checker, | 
					
						
						|  | feature_extractor=feature_extractor, | 
					
						
						|  | language_adapter=language_adapter, | 
					
						
						|  | tensor_norm=tensor_norm, | 
					
						
						|  | ) | 
					
						
						|  | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | 
					
						
						|  | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) | 
					
						
						|  | self.register_to_config(requires_safety_checker=requires_safety_checker) | 
					
						
						|  |  | 
					
						
						|  | def load_language_adapter( | 
					
						
						|  | self, | 
					
						
						|  | model_path: str, | 
					
						
						|  | num_token: int, | 
					
						
						|  | dim: int, | 
					
						
						|  | dim_out: int, | 
					
						
						|  | tensor_norm: torch.Tensor, | 
					
						
						|  | mult: int = 2, | 
					
						
						|  | depth: int = 5, | 
					
						
						|  | ): | 
					
						
						|  | device = self._execution_device | 
					
						
						|  | self.tensor_norm = tensor_norm.to(device) | 
					
						
						|  | self.language_adapter = TranslatorNoLN(num_tok=num_token, dim=dim, dim_out=dim_out, mult=mult, depth=depth).to( | 
					
						
						|  | device | 
					
						
						|  | ) | 
					
						
						|  | self.language_adapter.load_state_dict(torch.load(model_path)) | 
					
						
						|  |  | 
					
						
						|  | def _adapt_language(self, prompt_embeds: torch.Tensor): | 
					
						
						|  | prompt_embeds = prompt_embeds / 3 | 
					
						
						|  | prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2) | 
					
						
						|  | return prompt_embeds | 
					
						
						|  |  | 
					
						
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states. When `self.language_adapter` is set, embeddings that are
        computed here are also passed through it (see `_adapt_language`).

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If neither this nor `negative_prompt_embeds`
                is given, an empty prompt `""` is used for every batch entry. Ignored when
                `do_classifier_free_guidance` is false.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument. Pre-generated embeddings are
                not passed through the language adapter.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along the batch
            axis when computed/processed here. If `do_classifier_free_guidance` is false, `negative_prompt_embeds` is
            returned exactly as passed in (possibly `None`).
        """
        # Record the LoRA scale and apply it to the text encoder's LoRA layers for the duration of this call.
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # Legacy (non-PEFT) LoRA layers vs. PEFT layers need different scaling helpers.
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Batch size comes from the prompt if given, otherwise from the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to detect and report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # The attention mask is always passed when a language adapter is loaded, even if the encoder config
            # does not request it.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            elif self.language_adapter is not None:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]

            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # The last output element holds the per-layer hidden states; pick the layer `clip_skip`
                # positions before the final one.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # Apply the encoder's final LayerNorm, as the regular last hidden state would have.
                # NOTE(review): `text_model.final_layer_norm` is a CLIP text-model attribute; non-CLIP encoders
                # loaded through AutoModel may not have it -- confirm before using clip_skip with them.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

            # Map the encoder's hidden states through the GlueGen language adapter, if one is loaded.
            if self.language_adapter is not None:
                prompt_embeds = self._adapt_language(prompt_embeds)

        # Prefer the text encoder's dtype, then the UNet's, then keep the embeddings' own dtype.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # Duplicate each prompt's embeddings num_images_per_prompt times (repeat + view rather than
        # repeat_interleave).
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # Build unconditional embeddings for classifier-free guidance when none were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # NOTE(review): reached only when `prompt` is also a str or is None. With `prompt=None` and
                # batched `prompt_embeds`, this yields a single entry and the `.view(batch_size * ...)` below fails.
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            # NOTE(review): unlike the positive branch, the mask is NOT forced on when a language adapter is
            # loaded, so positive and negative prompts are encoded differently -- confirm this is intended.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

            if self.language_adapter is not None:
                negative_prompt_embeds = self._adapt_language(negative_prompt_embeds)

        if do_classifier_free_guidance:
            # Duplicate the unconditional embeddings per generated image, mirroring the positive path.
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Undo the LoRA scaling applied at the top of this method.
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
					
						
						|  |  | 
					
						
						|  | def run_safety_checker(self, image, device, dtype): | 
					
						
						|  | if self.safety_checker is None: | 
					
						
						|  | has_nsfw_concept = None | 
					
						
						|  | else: | 
					
						
						|  | if torch.is_tensor(image): | 
					
						
						|  | feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | 
					
						
						|  | else: | 
					
						
						|  | feature_extractor_input = self.image_processor.numpy_to_pil(image) | 
					
						
						|  | safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | 
					
						
						|  | image, has_nsfw_concept = self.safety_checker( | 
					
						
						|  | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | 
					
						
						|  | ) | 
					
						
						|  | return image, has_nsfw_concept | 
					
						
						|  |  | 
					
						
						|  | def prepare_extra_step_kwargs(self, generator, eta): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 
					
						
						|  | extra_step_kwargs = {} | 
					
						
						|  | if accepts_eta: | 
					
						
						|  | extra_step_kwargs["eta"] = eta | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 
					
						
						|  | if accepts_generator: | 
					
						
						|  | extra_step_kwargs["generator"] = generator | 
					
						
						|  | return extra_step_kwargs | 
					
						
						|  |  | 
					
						
						|  | def check_inputs( | 
					
						
						|  | self, | 
					
						
						|  | prompt, | 
					
						
						|  | height, | 
					
						
						|  | width, | 
					
						
						|  | negative_prompt=None, | 
					
						
						|  | prompt_embeds=None, | 
					
						
						|  | negative_prompt_embeds=None, | 
					
						
						|  | ): | 
					
						
						|  | if height % 8 != 0 or width % 8 != 0: | 
					
						
						|  | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 
					
						
						|  |  | 
					
						
						|  | if prompt is not None and prompt_embeds is not None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | 
					
						
						|  | " only forward one of the two." | 
					
						
						|  | ) | 
					
						
						|  | elif prompt is None and prompt_embeds is None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." | 
					
						
						|  | ) | 
					
						
						|  | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): | 
					
						
						|  | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 
					
						
						|  |  | 
					
						
						|  | if negative_prompt is not None and negative_prompt_embeds is not None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" | 
					
						
						|  | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if prompt_embeds is not None and negative_prompt_embeds is not None: | 
					
						
						|  | if prompt_embeds.shape != negative_prompt_embeds.shape: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" | 
					
						
						|  | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" | 
					
						
						|  | f" {negative_prompt_embeds.shape}." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | 
					
						
						|  | shape = ( | 
					
						
						|  | batch_size, | 
					
						
						|  | num_channels_latents, | 
					
						
						|  | int(height) // self.vae_scale_factor, | 
					
						
						|  | int(width) // self.vae_scale_factor, | 
					
						
						|  | ) | 
					
						
						|  | if isinstance(generator, list) and len(generator) != batch_size: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | 
					
						
						|  | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if latents is None: | 
					
						
						|  | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | 
					
						
						|  | else: | 
					
						
						|  | latents = latents.to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latents = latents * self.scheduler.init_noise_sigma | 
					
						
						|  | return latents | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): | 
					
						
						|  | """ | 
					
						
						|  | See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | timesteps (`torch.Tensor`): | 
					
						
						|  | generate embedding vectors at these timesteps | 
					
						
						|  | embedding_dim (`int`, *optional*, defaults to 512): | 
					
						
						|  | dimension of the embeddings to generate | 
					
						
						|  | dtype: | 
					
						
						|  | data type of the generated embeddings | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` | 
					
						
						|  | """ | 
					
						
						|  | assert len(w.shape) == 1 | 
					
						
						|  | w = w * 1000.0 | 
					
						
						|  |  | 
					
						
						|  | half_dim = embedding_dim // 2 | 
					
						
						|  | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) | 
					
						
						|  | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) | 
					
						
						|  | emb = w.to(dtype)[:, None] * emb[None, :] | 
					
						
						|  | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) | 
					
						
						|  | if embedding_dim % 2 == 1: | 
					
						
						|  | emb = torch.nn.functional.pad(emb, (0, 1)) | 
					
						
						|  | assert emb.shape == (w.shape[0], embedding_dim) | 
					
						
						|  | return emb | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def guidance_scale(self): | 
					
						
						|  | return self._guidance_scale | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def guidance_rescale(self): | 
					
						
						|  | return self._guidance_rescale | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def clip_skip(self): | 
					
						
						|  | return self._clip_skip | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def do_classifier_free_guidance(self): | 
					
						
						|  | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def cross_attention_kwargs(self): | 
					
						
						|  | return self._cross_attention_kwargs | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def num_timesteps(self): | 
					
						
						|  | return self._num_timesteps | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def interrupt(self): | 
					
						
						|  | return self._interrupt | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def __call__( | 
					
						
						|  | self, | 
					
						
						|  | prompt: Union[str, List[str]] = None, | 
					
						
						|  | height: Optional[int] = None, | 
					
						
						|  | width: Optional[int] = None, | 
					
						
						|  | num_inference_steps: int = 50, | 
					
						
						|  | timesteps: List[int] = None, | 
					
						
						|  | guidance_scale: float = 7.5, | 
					
						
						|  | negative_prompt: Optional[Union[str, List[str]]] = None, | 
					
						
						|  | num_images_per_prompt: Optional[int] = 1, | 
					
						
						|  | eta: float = 0.0, | 
					
						
						|  | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 
					
						
						|  | latents: Optional[torch.Tensor] = None, | 
					
						
						|  | prompt_embeds: Optional[torch.Tensor] = None, | 
					
						
						|  | negative_prompt_embeds: Optional[torch.Tensor] = None, | 
					
						
						|  | output_type: Optional[str] = "pil", | 
					
						
						|  | return_dict: bool = True, | 
					
						
						|  | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | 
					
						
						|  | guidance_rescale: float = 0.0, | 
					
						
						|  | clip_skip: Optional[int] = None, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | r""" | 
					
						
						|  | The call function to the pipeline for generation. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | prompt (`str` or `List[str]`, *optional*): | 
					
						
						|  | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. | 
					
						
						|  | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | 
					
						
						|  | The height in pixels of the generated image. | 
					
						
						|  | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | 
					
						
						|  | The width in pixels of the generated image. | 
					
						
						|  | num_inference_steps (`int`, *optional*, defaults to 50): | 
					
						
						|  | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | 
					
						
						|  | expense of slower inference. | 
					
						
						|  | timesteps (`List[int]`, *optional*): | 
					
						
						|  | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument | 
					
						
						|  | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is | 
					
						
						|  | passed will be used. Must be in descending order. | 
					
						
						|  | guidance_scale (`float`, *optional*, defaults to 7.5): | 
					
						
						|  | A higher guidance scale value encourages the model to generate images closely linked to the text | 
					
						
						|  | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. | 
					
						
						|  | negative_prompt (`str` or `List[str]`, *optional*): | 
					
						
						|  | The prompt or prompts to guide what to not include in image generation. If not defined, you need to | 
					
						
						|  | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). | 
					
						
						|  | num_images_per_prompt (`int`, *optional*, defaults to 1): | 
					
						
						|  | The number of images to generate per prompt. | 
					
						
						|  | eta (`float`, *optional*, defaults to 0.0): | 
					
						
						|  | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies | 
					
						
						|  | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. | 
					
						
						|  | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | 
					
						
						|  | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | 
					
						
						|  | generation deterministic. | 
					
						
						|  | latents (`torch.Tensor`, *optional*): | 
					
						
						|  | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image | 
					
						
						|  | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | 
					
						
						|  | tensor is generated by sampling using the supplied random `generator`. | 
					
						
						|  | prompt_embeds (`torch.Tensor`, *optional*): | 
					
						
						|  | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not | 
					
						
						|  | provided, text embeddings are generated from the `prompt` input argument. | 
					
						
						|  | negative_prompt_embeds (`torch.Tensor`, *optional*): | 
					
						
						|  | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If | 
					
						
						|  | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. | 
					
						
						|  | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. | 
					
						
						|  | output_type (`str`, *optional*, defaults to `"pil"`): | 
					
						
						|  | The output format of the generated image. Choose between `PIL.Image` or `np.array`. | 
					
						
						|  | return_dict (`bool`, *optional*, defaults to `True`): | 
					
						
						|  | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | 
					
						
						|  | plain tuple. | 
					
						
						|  | cross_attention_kwargs (`dict`, *optional*): | 
					
						
						|  | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in | 
					
						
						|  | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | 
					
						
						|  | guidance_rescale (`float`, *optional*, defaults to 0.0): | 
					
						
						|  | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are | 
					
						
						|  | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when | 
					
						
						|  | using zero terminal SNR. | 
					
						
						|  | clip_skip (`int`, *optional*): | 
					
						
						|  | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that | 
					
						
						|  | the output of the pre-final layer will be used for computing the prompt embeddings. | 
					
						
						|  |  | 
					
						
						|  | Examples: | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | 
					
						
						|  | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, | 
					
						
						|  | otherwise a `tuple` is returned where the first element is a list with the generated images and the | 
					
						
						|  | second element is a list of `bool`s indicating whether the corresponding generated image contains | 
					
						
						|  | "not-safe-for-work" (nsfw) content. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | height = height or self.unet.config.sample_size * self.vae_scale_factor | 
					
						
						|  | width = width or self.unet.config.sample_size * self.vae_scale_factor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.check_inputs( | 
					
						
						|  | prompt, | 
					
						
						|  | height, | 
					
						
						|  | width, | 
					
						
						|  | negative_prompt, | 
					
						
						|  | prompt_embeds, | 
					
						
						|  | negative_prompt_embeds, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._guidance_scale = guidance_scale | 
					
						
						|  | self._guidance_rescale = guidance_rescale | 
					
						
						|  | self._clip_skip = clip_skip | 
					
						
						|  | self._cross_attention_kwargs = cross_attention_kwargs | 
					
						
						|  | self._interrupt = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if prompt is not None and isinstance(prompt, str): | 
					
						
						|  | batch_size = 1 | 
					
						
						|  | elif prompt is not None and isinstance(prompt, list): | 
					
						
						|  | batch_size = len(prompt) | 
					
						
						|  | else: | 
					
						
						|  | batch_size = prompt_embeds.shape[0] | 
					
						
						|  |  | 
					
						
						|  | device = self._execution_device | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | lora_scale = ( | 
					
						
						|  | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | prompt_embeds, negative_prompt_embeds = self.encode_prompt( | 
					
						
						|  | prompt, | 
					
						
						|  | device, | 
					
						
						|  | num_images_per_prompt, | 
					
						
						|  | self.do_classifier_free_guidance, | 
					
						
						|  | negative_prompt, | 
					
						
						|  | prompt_embeds=prompt_embeds, | 
					
						
						|  | negative_prompt_embeds=negative_prompt_embeds, | 
					
						
						|  | lora_scale=lora_scale, | 
					
						
						|  | clip_skip=self.clip_skip, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.do_classifier_free_guidance: | 
					
						
						|  | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | num_channels_latents = self.unet.config.in_channels | 
					
						
						|  | latents = self.prepare_latents( | 
					
						
						|  | batch_size * num_images_per_prompt, | 
					
						
						|  | num_channels_latents, | 
					
						
						|  | height, | 
					
						
						|  | width, | 
					
						
						|  | prompt_embeds.dtype, | 
					
						
						|  | device, | 
					
						
						|  | generator, | 
					
						
						|  | latents, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | timestep_cond = None | 
					
						
						|  | if self.unet.config.time_cond_proj_dim is not None: | 
					
						
						|  | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) | 
					
						
						|  | timestep_cond = self.get_guidance_scale_embedding( | 
					
						
						|  | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim | 
					
						
						|  | ).to(device=device, dtype=latents.dtype) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | 
					
						
						|  | self._num_timesteps = len(timesteps) | 
					
						
						|  | with self.progress_bar(total=num_inference_steps) as progress_bar: | 
					
						
						|  | for i, t in enumerate(timesteps): | 
					
						
						|  | if self.interrupt: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents | 
					
						
						|  | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | noise_pred = self.unet( | 
					
						
						|  | latent_model_input, | 
					
						
						|  | t, | 
					
						
						|  | encoder_hidden_states=prompt_embeds, | 
					
						
						|  | timestep_cond=timestep_cond, | 
					
						
						|  | cross_attention_kwargs=self.cross_attention_kwargs, | 
					
						
						|  | return_dict=False, | 
					
						
						|  | )[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.do_classifier_free_guidance: | 
					
						
						|  | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 
					
						
						|  | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | 
					
						
						|  |  | 
					
						
						|  | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: | 
					
						
						|  |  | 
					
						
						|  | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 
					
						
						|  | progress_bar.update() | 
					
						
						|  |  | 
					
						
						|  | if not output_type == "latent": | 
					
						
						|  | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ | 
					
						
						|  | 0 | 
					
						
						|  | ] | 
					
						
						|  | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | 
					
						
						|  | else: | 
					
						
						|  | image = latents | 
					
						
						|  | has_nsfw_concept = None | 
					
						
						|  |  | 
					
						
						|  | if has_nsfw_concept is None: | 
					
						
						|  | do_denormalize = [True] * image.shape[0] | 
					
						
						|  | else: | 
					
						
						|  | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | 
					
						
						|  |  | 
					
						
						|  | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.maybe_free_model_hooks() | 
					
						
						|  |  | 
					
						
						|  | if not return_dict: | 
					
						
						|  | return (image, has_nsfw_concept) | 
					
						
						|  |  | 
					
						
						|  | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | 
					
						
						|  |  |