naonauno committed
Commit 3f525da · verified · 1 parent: 7e720cc

Upload 2 files

Files changed (2)
  1. model.py +303 -0
  2. pipeline.py +1377 -0
model.py ADDED
@@ -0,0 +1,303 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import copy
import torch
from torch import nn, svd_lowrank

from peft.tuners.lora import LoraLayer, Conv2d as PeftConv2d
from diffusers.configuration_utils import register_to_config
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel


class UNet2DConditionModelEx(UNet2DConditionModel):
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
        extra_condition_names: List[str] = [],
    ):
        num_extra_conditions = len(extra_condition_names)
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels * (1 + num_extra_conditions),
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            down_block_types=down_block_types,
            mid_block_type=mid_block_type,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            downsample_padding=downsample_padding,
            mid_block_scale_factor=mid_block_scale_factor,
            dropout=dropout,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            class_embed_type=class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=num_class_embeds,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            time_embedding_act_fn=time_embedding_act_fn,
            timestep_post_act=timestep_post_act,
            time_cond_proj_dim=time_cond_proj_dim,
            conv_in_kernel=conv_in_kernel,
            conv_out_kernel=conv_out_kernel,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            attention_type=attention_type,
            class_embeddings_concat=class_embeddings_concat,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
        )
        self._internal_dict = copy.deepcopy(self._internal_dict)
        self.config.in_channels = in_channels
        self.config.extra_condition_names = extra_condition_names

    @property
    def extra_condition_names(self) -> List[str]:
        return self.config.extra_condition_names

    def add_extra_conditions(self, extra_condition_names: Union[str, List[str]]):
        if isinstance(extra_condition_names, str):
            extra_condition_names = [extra_condition_names]
        conv_in_kernel = self.config.conv_in_kernel
        conv_in_weight = self.conv_in.weight
        self.config.extra_condition_names += extra_condition_names
        full_in_channels = self.config.in_channels * (1 + len(self.config.extra_condition_names))
        new_conv_in_weight = torch.zeros(
            conv_in_weight.shape[0], full_in_channels, conv_in_kernel, conv_in_kernel,
            dtype=conv_in_weight.dtype,
            device=conv_in_weight.device,
        )
        new_conv_in_weight[:, :conv_in_weight.shape[1]] = conv_in_weight
        self.conv_in.weight = nn.Parameter(
            new_conv_in_weight.data,
            requires_grad=conv_in_weight.requires_grad,
        )
        self.conv_in.in_channels = full_in_channels

        return self

    def activate_extra_condition_adapters(self):
        lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
        if len(lora_layers) > 0:
            self._hf_peft_config_loaded = True
        for lora_layer in lora_layers:
            adapter_names = [k for k in lora_layer.scaling.keys() if k in self.config.extra_condition_names]
            adapter_names += lora_layer.active_adapters
            adapter_names = list(set(adapter_names))
            lora_layer.set_adapter(adapter_names)

    def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
        if isinstance(scale, float):
            scale = [scale] * len(self.config.extra_condition_names)

        lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
        for s, n in zip(scale, self.config.extra_condition_names):
            for lora_layer in lora_layers:
                lora_layer.set_scale(n, s)

    @property
    def default_half_lora_target_modules(self) -> List[str]:
        module_names = []
        for name, module in self.named_modules():
            if "conv_out" in name or "up_blocks" in name:
                continue
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module_names.append(name)
        return list(set(module_names))

    @property
    def default_full_lora_target_modules(self) -> List[str]:
        module_names = []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module_names.append(name)
        return list(set(module_names))

    @property
    def default_half_skip_attn_lora_target_modules(self) -> List[str]:
        return [
            module_name
            for module_name in self.default_half_lora_target_modules
            if all(
                not module_name.endswith(attn_name)
                for attn_name in ["to_k", "to_q", "to_v", "to_out.0"]
            )
        ]

    @property
    def default_full_skip_attn_lora_target_modules(self) -> List[str]:
        return [
            module_name
            for module_name in self.default_full_lora_target_modules
            if all(
                not module_name.endswith(attn_name)
                for attn_name in ["to_k", "to_q", "to_v", "to_out.0"]
            )
        ]

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        extra_conditions: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        if extra_conditions is not None:
            if isinstance(extra_conditions, list):
                extra_conditions = torch.cat(extra_conditions, dim=1)
            sample = torch.cat([sample, extra_conditions], dim=1)
        return super().forward(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            timestep_cond=timestep_cond,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=return_dict,
        )


class PeftConv2dEx(PeftConv2d):
    def reset_lora_parameters(self, adapter_name, init_lora_weights):
        if init_lora_weights is False:
            return

        if isinstance(init_lora_weights, str) and "pissa" in init_lora_weights.lower():
            if self.conv2d_pissa_init(adapter_name, init_lora_weights):
                return
            # Failed
            init_lora_weights = "gaussian"

        super(PeftConv2d, self).reset_lora_parameters(adapter_name, init_lora_weights)

    def conv2d_pissa_init(self, adapter_name, init_lora_weights):
        weight = weight_ori = self.get_base_layer().weight
        weight = weight.flatten(start_dim=1)
        if self.r[adapter_name] > weight.shape[0]:
            return False
        dtype = weight.dtype
        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            raise TypeError(
                "Please initialize PiSSA under float32, float16, or bfloat16. "
                "Subsequently, re-quantize the residual model to help minimize quantization errors."
            )
        weight = weight.to(torch.float32)

        if init_lora_weights == "pissa":
            # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}
            V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
            Vr = V[:, : self.r[adapter_name]]
            Sr = S[: self.r[adapter_name]]
            Sr /= self.scaling[adapter_name]
            Uhr = Uh[: self.r[adapter_name]]
        elif len(init_lora_weights.split("_niter_")) == 2:
            Vr, Sr, Ur = svd_lowrank(
                weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
            )
            Sr /= self.scaling[adapter_name]
            Uhr = Ur.t()
        else:
            raise ValueError(
                f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
            )

        lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
        lora_B = Vr @ torch.diag(torch.sqrt(Sr))
        self.lora_A[adapter_name].weight.data = lora_A.view([-1] + list(weight_ori.shape[1:]))
        self.lora_B[adapter_name].weight.data = lora_B.view([-1, self.r[adapter_name]] + [1] * (weight_ori.ndim - 2))
        weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
        weight = weight.to(dtype)
        self.get_base_layer().weight.data = weight.view_as(weight_ori)

        return True


# Patch peft conv2d
PeftConv2d.reset_lora_parameters = PeftConv2dEx.reset_lora_parameters
PeftConv2d.conv2d_pissa_init = PeftConv2dEx.conv2d_pissa_init
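For orientation, here is a minimal usage sketch of the class above, following the example docstring in pipeline.py below (the model ID and LoRA repository come from that example and may need to be adjusted):

```py
import torch

from model import UNet2DConditionModelEx
from pipeline import StableDiffusionControlLoraV3Pipeline

# add_extra_conditions() widens conv_in with zero-initialized weights so each
# named condition gets its own in_channels-sized slice; forward() later
# concatenates the extra conditions onto the latent sample along dim=1.
unet = UNet2DConditionModelEx.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet = unet.add_extra_conditions(["canny"])

# The pipeline's load_lora_weights() (defined below) loads one LoRA adapter per
# extra-condition name and activates it via activate_extra_condition_adapters().
pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
)
pipe.load_lora_weights("HighCWu/sd-control-lora-v3-canny")
```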
pipeline.py ADDED
@@ -0,0 +1,1377 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
9
+
10
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models import AutoencoderKL, ImageProjection
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ deprecate,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
26
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
28
+ from model import UNet2DConditionModelEx
29
+
30
+
31
+ from huggingface_hub.utils import validate_hf_hub_args
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> # !pip install opencv-python transformers accelerate
41
+ >>> from diffusers import UniPCMultistepScheduler
42
+ >>> from diffusers.utils import load_image
43
+ >>> from model import UNet2DConditionModelEx
44
+ >>> from pipeline import StableDiffusionControlLoraV3Pipeline
45
+ >>> import numpy as np
46
+ >>> import torch
47
+
48
+ >>> import cv2
49
+ >>> from PIL import Image
50
+
51
+ >>> # download an image
52
+ >>> image = load_image(
53
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
54
+ ... )
55
+ >>> image = np.array(image)
56
+
57
+ >>> # get canny image
58
+ >>> image = cv2.Canny(image, 100, 200)
59
+ >>> image = image[:, :, None]
60
+ >>> image = np.concatenate([image, image, image], axis=2)
61
+ >>> canny_image = Image.fromarray(image)
62
+
63
+ >>> # load stable diffusion v1-5 and control-lora-v3
64
+ >>> unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
65
+ ... "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
66
+ ... )
67
+ >>> unet = unet.add_extra_conditions(["canny"])
68
+ >>> pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
69
+ ... "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
70
+ ... )
71
+ >>> # load attention processors
72
+ >>> pipe.load_lora_weights("HighCWu/sd-control-lora-v3-canny")
73
+
74
+ >>> # speed up diffusion process with faster scheduler and memory optimization
75
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
76
+ >>> # remove following line if xformers is not installed
77
+ >>> pipe.enable_xformers_memory_efficient_attention()
78
+
79
+ >>> pipe.enable_model_cpu_offload()
80
+
81
+ >>> # generate image
82
+ >>> generator = torch.manual_seed(0)
83
+ >>> image = pipe(
84
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
85
+ ... ).images[0]
86
+ ```
87
+ """
88
+
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
91
+ def retrieve_timesteps(
92
+ scheduler,
93
+ num_inference_steps: Optional[int] = None,
94
+ device: Optional[Union[str, torch.device]] = None,
95
+ timesteps: Optional[List[int]] = None,
96
+ sigmas: Optional[List[float]] = None,
97
+ **kwargs,
98
+ ):
99
+ """
100
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
101
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
102
+
103
+ Args:
104
+ scheduler (`SchedulerMixin`):
105
+ The scheduler to get timesteps from.
106
+ num_inference_steps (`int`):
107
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
108
+ must be `None`.
109
+ device (`str` or `torch.device`, *optional*):
110
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
111
+ timesteps (`List[int]`, *optional*):
112
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
113
+ `num_inference_steps` and `sigmas` must be `None`.
114
+ sigmas (`List[float]`, *optional*):
115
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
116
+ `num_inference_steps` and `timesteps` must be `None`.
117
+
118
+ Returns:
119
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
120
+ second element is the number of inference steps.
121
+ """
122
+ if timesteps is not None and sigmas is not None:
123
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
124
+ if timesteps is not None:
125
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accept_sigmas:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ else:
145
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ return timesteps, num_inference_steps
148
+
149
+
150
+ class StableDiffusionControlLoraV3Pipeline(
151
+ DiffusionPipeline,
152
+ StableDiffusionMixin,
153
+ TextualInversionLoaderMixin,
154
+ LoraLoaderMixin,
155
+ IPAdapterMixin,
156
+ FromSingleFileMixin,
157
+ ):
158
+ r"""
159
+ Pipeline for text-to-image generation using Stable Diffusion with extra condition guidance.
160
+
161
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
162
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
163
+
164
+ The pipeline also inherits the following loading methods:
165
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
166
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
167
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
168
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
169
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
170
+
171
+ Args:
172
+ vae ([`AutoencoderKL`]):
173
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
174
+ text_encoder ([`~transformers.CLIPTextModel`]):
175
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
176
+ tokenizer ([`~transformers.CLIPTokenizer`]):
177
+ A `CLIPTokenizer` to tokenize text.
178
+ unet ([`UNet2DConditionModelEx`]):
179
+ A `UNet2DConditionModelEx` to denoise the encoded image latents with extra conditions.
180
+ scheduler ([`SchedulerMixin`]):
181
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
182
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
183
+ safety_checker ([`StableDiffusionSafetyChecker`]):
184
+ Classification module that estimates whether generated images could be considered offensive or harmful.
185
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
186
+ about a model's potential harms.
187
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
188
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
192
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
193
+ _exclude_from_cpu_offload = ["safety_checker"]
194
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
195
+
196
+ def __init__(
197
+ self,
198
+ vae: AutoencoderKL,
199
+ text_encoder: CLIPTextModel,
200
+ tokenizer: CLIPTokenizer,
201
+ unet: UNet2DConditionModelEx,
202
+ scheduler: KarrasDiffusionSchedulers,
203
+ safety_checker: StableDiffusionSafetyChecker,
204
+ feature_extractor: CLIPImageProcessor,
205
+ image_encoder: CLIPVisionModelWithProjection = None,
206
+ requires_safety_checker: bool = True,
207
+ ):
208
+ super().__init__()
209
+
210
+ if safety_checker is None and requires_safety_checker:
211
+ logger.warning(
212
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
213
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
214
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
215
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
216
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
217
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
218
+ )
219
+
220
+ if safety_checker is not None and feature_extractor is None:
221
+ raise ValueError(
222
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
223
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
224
+ )
225
+
226
+ self.register_modules(
227
+ vae=vae,
228
+ text_encoder=text_encoder,
229
+ tokenizer=tokenizer,
230
+ unet=unet,
231
+ scheduler=scheduler,
232
+ safety_checker=safety_checker,
233
+ feature_extractor=feature_extractor,
234
+ image_encoder=image_encoder,
235
+ )
236
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
237
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
238
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
239
+
240
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
241
+ def _encode_prompt(
242
+ self,
243
+ prompt,
244
+ device,
245
+ num_images_per_prompt,
246
+ do_classifier_free_guidance,
247
+ negative_prompt=None,
248
+ prompt_embeds: Optional[torch.Tensor] = None,
249
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
250
+ lora_scale: Optional[float] = None,
251
+ **kwargs,
252
+ ):
253
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
254
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
255
+
256
+ prompt_embeds_tuple = self.encode_prompt(
257
+ prompt=prompt,
258
+ device=device,
259
+ num_images_per_prompt=num_images_per_prompt,
260
+ do_classifier_free_guidance=do_classifier_free_guidance,
261
+ negative_prompt=negative_prompt,
262
+ prompt_embeds=prompt_embeds,
263
+ negative_prompt_embeds=negative_prompt_embeds,
264
+ lora_scale=lora_scale,
265
+ **kwargs,
266
+ )
267
+
268
+ # concatenate for backwards comp
269
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
270
+
271
+ return prompt_embeds
272
+
273
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
274
+ def encode_prompt(
275
+ self,
276
+ prompt,
277
+ device,
278
+ num_images_per_prompt,
279
+ do_classifier_free_guidance,
280
+ negative_prompt=None,
281
+ prompt_embeds: Optional[torch.Tensor] = None,
282
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
283
+ lora_scale: Optional[float] = None,
284
+ clip_skip: Optional[int] = None,
285
+ ):
286
+ r"""
287
+ Encodes the prompt into text encoder hidden states.
288
+
289
+ Args:
290
+ prompt (`str` or `List[str]`, *optional*):
291
+ prompt to be encoded
292
+ device: (`torch.device`):
293
+ torch device
294
+ num_images_per_prompt (`int`):
295
+ number of images that should be generated per prompt
296
+ do_classifier_free_guidance (`bool`):
297
+ whether to use classifier free guidance or not
298
+ negative_prompt (`str` or `List[str]`, *optional*):
299
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
300
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
301
+ less than `1`).
302
+ prompt_embeds (`torch.Tensor`, *optional*):
303
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
304
+ provided, text embeddings will be generated from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
306
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
307
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
308
+ argument.
309
+ lora_scale (`float`, *optional*):
310
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
311
+ clip_skip (`int`, *optional*):
312
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
313
+ the output of the pre-final layer will be used for computing the prompt embeddings.
314
+ """
315
+ # set lora scale so that monkey patched LoRA
316
+ # function of text encoder can correctly access it
317
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
318
+ self._lora_scale = lora_scale
319
+
320
+ # dynamically adjust the LoRA scale
321
+ if not USE_PEFT_BACKEND:
322
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
323
+ else:
324
+ scale_lora_layers(self.text_encoder, lora_scale)
325
+
326
+ if prompt is not None and isinstance(prompt, str):
327
+ batch_size = 1
328
+ elif prompt is not None and isinstance(prompt, list):
329
+ batch_size = len(prompt)
330
+ else:
331
+ batch_size = prompt_embeds.shape[0]
332
+
333
+ if prompt_embeds is None:
334
+ # textual inversion: process multi-vector tokens if necessary
335
+ if isinstance(self, TextualInversionLoaderMixin):
336
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
337
+
338
+ text_inputs = self.tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=self.tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+ text_input_ids = text_inputs.input_ids
346
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
347
+
348
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
349
+ text_input_ids, untruncated_ids
350
+ ):
351
+ removed_text = self.tokenizer.batch_decode(
352
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
353
+ )
354
+ logger.warning(
355
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
356
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
357
+ )
358
+
359
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
360
+ attention_mask = text_inputs.attention_mask.to(device)
361
+ else:
362
+ attention_mask = None
363
+
364
+ if clip_skip is None:
365
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
366
+ prompt_embeds = prompt_embeds[0]
367
+ else:
368
+ prompt_embeds = self.text_encoder(
369
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
370
+ )
371
+ # Access the `hidden_states` first, that contains a tuple of
372
+ # all the hidden states from the encoder layers. Then index into
373
+ # the tuple to access the hidden states from the desired layer.
374
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
375
+ # We also need to apply the final LayerNorm here to not mess with the
376
+ # representations. The `last_hidden_states` that we typically use for
377
+ # obtaining the final prompt representations passes through the LayerNorm
378
+ # layer.
379
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
380
+
381
+ if self.text_encoder is not None:
382
+ prompt_embeds_dtype = self.text_encoder.dtype
383
+ elif self.unet is not None:
384
+ prompt_embeds_dtype = self.unet.dtype
385
+ else:
386
+ prompt_embeds_dtype = prompt_embeds.dtype
387
+
388
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ bs_embed, seq_len, _ = prompt_embeds.shape
391
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
392
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
393
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
394
+
395
+ # get unconditional embeddings for classifier free guidance
396
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
397
+ uncond_tokens: List[str]
398
+ if negative_prompt is None:
399
+ uncond_tokens = [""] * batch_size
400
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
401
+ raise TypeError(
402
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
403
+ f" {type(prompt)}."
404
+ )
405
+ elif isinstance(negative_prompt, str):
406
+ uncond_tokens = [negative_prompt]
407
+ elif batch_size != len(negative_prompt):
408
+ raise ValueError(
409
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
410
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
411
+ " the batch size of `prompt`."
412
+ )
413
+ else:
414
+ uncond_tokens = negative_prompt
415
+
416
+ # textual inversion: process multi-vector tokens if necessary
417
+ if isinstance(self, TextualInversionLoaderMixin):
418
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
419
+
420
+ max_length = prompt_embeds.shape[1]
421
+ uncond_input = self.tokenizer(
422
+ uncond_tokens,
423
+ padding="max_length",
424
+ max_length=max_length,
425
+ truncation=True,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
430
+ attention_mask = uncond_input.attention_mask.to(device)
431
+ else:
432
+ attention_mask = None
433
+
434
+ negative_prompt_embeds = self.text_encoder(
435
+ uncond_input.input_ids.to(device),
436
+ attention_mask=attention_mask,
437
+ )
438
+ negative_prompt_embeds = negative_prompt_embeds[0]
439
+
440
+ if do_classifier_free_guidance:
441
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
442
+ seq_len = negative_prompt_embeds.shape[1]
443
+
444
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
447
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
448
+
449
+ if self.text_encoder is not None:
450
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
451
+ # Retrieve the original scale by scaling back the LoRA layers
452
+ unscale_lora_layers(self.text_encoder, lora_scale)
453
+
454
+ return prompt_embeds, negative_prompt_embeds
455
+
456
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
457
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
458
+ dtype = next(self.image_encoder.parameters()).dtype
459
+
460
+ if not isinstance(image, torch.Tensor):
461
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+ if output_hidden_states:
465
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
466
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
467
+ uncond_image_enc_hidden_states = self.image_encoder(
468
+ torch.zeros_like(image), output_hidden_states=True
469
+ ).hidden_states[-2]
470
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
471
+ num_images_per_prompt, dim=0
472
+ )
473
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
474
+ else:
475
+ image_embeds = self.image_encoder(image).image_embeds
476
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
477
+ uncond_image_embeds = torch.zeros_like(image_embeds)
478
+
479
+ return image_embeds, uncond_image_embeds
480
+
481
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
482
+ def prepare_ip_adapter_image_embeds(
483
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
484
+ ):
485
+ if ip_adapter_image_embeds is None:
486
+ if not isinstance(ip_adapter_image, list):
487
+ ip_adapter_image = [ip_adapter_image]
488
+
489
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
490
+ raise ValueError(
491
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
492
+ )
493
+
494
+ image_embeds = []
495
+ for single_ip_adapter_image, image_proj_layer in zip(
496
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
497
+ ):
498
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
499
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
500
+ single_ip_adapter_image, device, 1, output_hidden_state
501
+ )
502
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
503
+ single_negative_image_embeds = torch.stack(
504
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
505
+ )
506
+
507
+ if do_classifier_free_guidance:
508
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
509
+ single_image_embeds = single_image_embeds.to(device)
510
+
511
+ image_embeds.append(single_image_embeds)
512
+ else:
513
+ repeat_dims = [1]
514
+ image_embeds = []
515
+ for single_image_embeds in ip_adapter_image_embeds:
516
+ if do_classifier_free_guidance:
517
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
518
+ single_image_embeds = single_image_embeds.repeat(
519
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520
+ )
521
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
522
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
523
+ )
524
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
525
+ else:
526
+ single_image_embeds = single_image_embeds.repeat(
527
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
528
+ )
529
+ image_embeds.append(single_image_embeds)
530
+
531
+ return image_embeds
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
534
+ def run_safety_checker(self, image, device, dtype):
535
+ if self.safety_checker is None:
536
+ has_nsfw_concept = None
537
+ else:
538
+ if torch.is_tensor(image):
539
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
540
+ else:
541
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
542
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
543
+ image, has_nsfw_concept = self.safety_checker(
544
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
545
+ )
546
+ return image, has_nsfw_concept
547
+
548
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
549
+ def decode_latents(self, latents):
550
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
551
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
552
+
553
+ latents = 1 / self.vae.config.scaling_factor * latents
554
+ image = self.vae.decode(latents, return_dict=False)[0]
555
+ image = (image / 2 + 0.5).clamp(0, 1)
556
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
557
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
558
+ return image
559
+
560
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
561
+ def prepare_extra_step_kwargs(self, generator, eta):
562
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
563
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
564
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
565
+ # and should be between [0, 1]
566
+
567
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
568
+ extra_step_kwargs = {}
569
+ if accepts_eta:
570
+ extra_step_kwargs["eta"] = eta
571
+
572
+ # check if the scheduler accepts generator
573
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
574
+ if accepts_generator:
575
+ extra_step_kwargs["generator"] = generator
576
+ return extra_step_kwargs
577
+
578
+ def check_inputs(
579
+ self,
580
+ prompt,
581
+ image,
582
+ callback_steps,
583
+ negative_prompt=None,
584
+ prompt_embeds=None,
585
+ negative_prompt_embeds=None,
586
+ ip_adapter_image=None,
587
+ ip_adapter_image_embeds=None,
588
+ extra_condition_scale=1.0,
589
+ control_guidance_start=0.0,
590
+ control_guidance_end=1.0,
591
+ callback_on_step_end_tensor_inputs=None,
592
+ ):
593
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
594
+ raise ValueError(
595
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
596
+ f" {type(callback_steps)}."
597
+ )
598
+
599
+ if callback_on_step_end_tensor_inputs is not None and not all(
600
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
601
+ ):
602
+ raise ValueError(
603
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
604
+ )
605
+
606
+ if prompt is not None and prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
609
+ " only forward one of the two."
610
+ )
611
+ elif prompt is None and prompt_embeds is None:
612
+ raise ValueError(
613
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
614
+ )
615
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
616
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
617
+
618
+ if negative_prompt is not None and negative_prompt_embeds is not None:
619
+ raise ValueError(
620
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
621
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
622
+ )
623
+
624
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
625
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
626
+ raise ValueError(
627
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
628
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
629
+ f" {negative_prompt_embeds.shape}."
630
+ )
631
+
632
+ # Check `image`
633
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
634
+ num_extra_conditions = len(unet.extra_condition_names)
635
+ if num_extra_conditions == 1:
636
+ self.check_image(image, prompt, prompt_embeds)
637
+ elif num_extra_conditions > 1:
638
+ if not isinstance(image, list):
639
+ raise TypeError("For multiple extra conditions: `image` must be type `list`")
640
+
641
+ # When `image` is a nested list:
642
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
643
+ elif any(isinstance(i, list) for i in image):
644
+ transposed_image = [list(t) for t in zip(*image)]
645
+ if len(transposed_image) != num_extra_conditions:
646
+ raise ValueError(
647
+ f"For multiple extra conditions: if you pass`image` as a list of list, each sublist must have the same length as the number of extra conditions, but the sublists in `image` got {len(transposed_image)} images and {num_extra_conditions} extra conditions."
648
+ )
649
+ for image_ in transposed_image:
650
+ self.check_image(image_, prompt, prompt_embeds)
651
+ elif len(image) != num_extra_conditions:
652
+ raise ValueError(
653
+ f"For multiple extra conditions: `image` must have the same length as the number of extra conditions, but got {len(image)} images and {num_extra_conditions} extra conditions."
654
+ )
655
+ else:
656
+ for image_ in image:
657
+ self.check_image(image_, prompt, prompt_embeds)
658
+ else:
659
+ assert False
660
+
661
+ # Check `extra_condition_scale`
662
+ if num_extra_conditions == 1:
663
+ if not isinstance(extra_condition_scale, float):
664
+ raise TypeError("For single extra condition: `extra_condition_scale` must be type `float`.")
665
+ elif num_extra_conditions >= 1:
666
+ if isinstance(extra_condition_scale, list):
667
+ if any(isinstance(i, list) for i in extra_condition_scale):
668
+ raise ValueError(
669
+ "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
670
+ "The conditioning scale must be fixed across the batch."
671
+ )
672
+ elif isinstance(extra_condition_scale, list) and len(extra_condition_scale) != num_extra_conditions:
673
+ raise ValueError(
674
+ "For multiple extra conditions: When `extra_condition_scale` is specified as `list`, it must have"
675
+ " the same length as the number of extra conditions"
676
+ )
677
+ else:
678
+ assert False
679
+
680
+ if not isinstance(control_guidance_start, (tuple, list)):
681
+ control_guidance_start = [control_guidance_start]
682
+
683
+ if not isinstance(control_guidance_end, (tuple, list)):
684
+ control_guidance_end = [control_guidance_end]
685
+
686
+ if len(control_guidance_start) != len(control_guidance_end):
687
+ raise ValueError(
688
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
689
+ )
690
+
691
+ if num_extra_conditions > 1:
692
+ if len(control_guidance_start) != num_extra_conditions:
693
+ raise ValueError(
694
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_extra_conditions} extra conditions available. Make sure to provide {num_extra_conditions}."
695
+ )
696
+
697
+ for start, end in zip(control_guidance_start, control_guidance_end):
698
+ if start >= end:
699
+ raise ValueError(
700
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
701
+ )
702
+ if start < 0.0:
703
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
704
+ if end > 1.0:
705
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
706
+
707
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
708
+ raise ValueError(
709
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
710
+ )
711
+
712
+ if ip_adapter_image_embeds is not None:
713
+ if not isinstance(ip_adapter_image_embeds, list):
714
+ raise ValueError(
715
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
716
+ )
717
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
718
+ raise ValueError(
719
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
720
+ )
721
+
722
+ def check_image(self, image, prompt, prompt_embeds):
723
+ image_is_pil = isinstance(image, PIL.Image.Image)
724
+ image_is_tensor = isinstance(image, torch.Tensor)
725
+ image_is_np = isinstance(image, np.ndarray)
726
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
727
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
728
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
729
+
730
+ if (
731
+ not image_is_pil
732
+ and not image_is_tensor
733
+ and not image_is_np
734
+ and not image_is_pil_list
735
+ and not image_is_tensor_list
736
+ and not image_is_np_list
737
+ ):
738
+ raise TypeError(
739
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
740
+ )
741
+
742
+ if image_is_pil:
743
+ image_batch_size = 1
744
+ else:
745
+ image_batch_size = len(image)
746
+
747
+ if prompt is not None and isinstance(prompt, str):
748
+ prompt_batch_size = 1
749
+ elif prompt is not None and isinstance(prompt, list):
750
+ prompt_batch_size = len(prompt)
751
+ elif prompt_embeds is not None:
752
+ prompt_batch_size = prompt_embeds.shape[0]
753
+
754
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
755
+ raise ValueError(
756
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
757
+ )
758
+
759
+ def prepare_image(
760
+ self,
761
+ image,
762
+ width,
763
+ height,
764
+ batch_size,
765
+ num_images_per_prompt,
766
+ device,
767
+ dtype,
768
+ do_classifier_free_guidance=False,
769
+ ):
770
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
771
+ image_batch_size = image.shape[0]
772
+
773
+ if image_batch_size == 1:
774
+ repeat_by = batch_size
775
+ else:
776
+ # image batch size is the same as prompt batch size
777
+ repeat_by = num_images_per_prompt
778
+
779
+ image = image.repeat_interleave(repeat_by, dim=0)
780
+
781
+ image = image.to(device=device, dtype=dtype)
782
+
783
+ if do_classifier_free_guidance:
784
+ image = torch.cat([image] * 2)
785
+
786
+ return image
787
+
788
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
789
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
790
+ shape = (
791
+ batch_size,
792
+ num_channels_latents,
793
+ int(height) // self.vae_scale_factor,
794
+ int(width) // self.vae_scale_factor,
795
+ )
796
+ if isinstance(generator, list) and len(generator) != batch_size:
797
+ raise ValueError(
798
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
799
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
800
+ )
801
+
802
+ if latents is None:
803
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
804
+ else:
805
+ latents = latents.to(device)
806
+
807
+ # scale the initial noise by the standard deviation required by the scheduler
808
+ latents = latents * self.scheduler.init_noise_sigma
809
+ return latents
810
+
811
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
812
+ def get_guidance_scale_embedding(
813
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
814
+ ) -> torch.Tensor:
815
+ """
816
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
817
+
818
+ Args:
819
+ w (`torch.Tensor`):
820
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
821
+ embedding_dim (`int`, *optional*, defaults to 512):
822
+ Dimension of the embeddings to generate.
823
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
824
+ Data type of the generated embeddings.
825
+
826
+ Returns:
827
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
828
+ """
829
+ assert len(w.shape) == 1
830
+ w = w * 1000.0
831
+
832
+ half_dim = embedding_dim // 2
833
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
834
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
835
+ emb = w.to(dtype)[:, None] * emb[None, :]
836
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
837
+ if embedding_dim % 2 == 1: # zero pad
838
+ emb = torch.nn.functional.pad(emb, (0, 1))
839
+ assert emb.shape == (w.shape[0], embedding_dim)
840
+ return emb
841
+
842
+ @property
843
+ def guidance_scale(self):
844
+ return self._guidance_scale
845
+
846
+ @property
847
+ def clip_skip(self):
848
+ return self._clip_skip
849
+
850
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
851
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
852
+ # corresponds to doing no classifier free guidance.
853
+ @property
854
+ def do_classifier_free_guidance(self):
855
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
856
+
857
+ @property
858
+ def cross_attention_kwargs(self):
859
+ return self._cross_attention_kwargs
860
+
861
+ @property
862
+ def num_timesteps(self):
863
+ return self._num_timesteps
864
+
865
+ @classmethod
866
+ @validate_hf_hub_args
867
+ def lora_state_dict(
868
+ cls,
869
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
870
+ **kwargs,
871
+ ):
872
+ # Override to add support for different LoRA alphas
873
+ state_dict, network_alphas = super(StableDiffusionControlLoraV3Pipeline, cls).lora_state_dict(
874
+ pretrained_model_name_or_path_or_dict, **kwargs
875
+ )
876
+ if network_alphas is None:
877
+ network_alphas = {}
878
+ for k, v in state_dict.items():
879
+ if ".lora_A." in k:
880
+ network_alphas[".".join(k.split(".lora_A.")[0].split(".") + ["alpha"])] = v.shape[0]
881
+ return state_dict, network_alphas
882
+
883
+ def load_lora_weights(
884
+ self,
885
+ pretrained_model_name_or_path_or_dict: Union[
886
+ Union[str, Dict[str, torch.Tensor]],
887
+ List[Union[str, Dict[str, torch.Tensor]]]
888
+ ],
889
+ adapter_name=None,
890
+ **kwargs
891
+ ):
892
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
893
+ num_condition_names = len(unet.extra_condition_names)
894
+ in_channels = unet.config.in_channels
895
+
896
+ kwargs["weight_name"] = kwargs.pop("weight_name", "pytorch_lora_weights.safetensors")
897
+
898
+ if adapter_name is not None and adapter_name not in unet.extra_condition_names:
899
+ unet._hf_peft_config_loaded = True
900
+ super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
901
+ unet.set_adapter(adapter_name)
902
+ return
903
+
904
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
905
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] * num_condition_names
906
+ pretrained_model_name_or_path_or_dict_list = pretrained_model_name_or_path_or_dict
907
+
908
+ assert len(pretrained_model_name_or_path_or_dict) == len(unet.extra_condition_names)
909
+
910
+ adapter_name_ori = adapter_name
911
+ for i, (pretrained_model_name_or_path_or_dict, adapter_name) in enumerate(zip(
912
+ pretrained_model_name_or_path_or_dict_list,
913
+ unet.extra_condition_names
914
+ )):
915
+ _kwargs = {**kwargs}
916
+ subfolder = _kwargs.pop("subfolder", None)
917
+ if isinstance(subfolder, list):
918
+ subfolder = subfolder[i]
919
+
920
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
921
+ pretrained_model_name_or_path_or_dict, _ = self.lora_state_dict(
922
+ pretrained_model_name_or_path_or_dict,
923
+ subfolder=subfolder,
924
+ **_kwargs
925
+ )
926
+
927
+ if adapter_name_ori is not None:
928
+ # only load the LoRA for the requested adapter name, then break out of the loop
929
+ i = unet.extra_condition_names.index(adapter_name_ori)
930
+ adapter_name = adapter_name_ori
931
+
932
+ unet_conv_in_lora_A_name, old_weight = ([
933
+ (k, v)
934
+ for k, v in pretrained_model_name_or_path_or_dict.items()
935
+ if "unet." in k and ".conv_in." in k and ".lora_A." in k
936
+ ] + [(None, None)])[0]
937
+ if unet_conv_in_lora_A_name is not None:
938
+ in_weight = old_weight[:,:in_channels]
939
+ cond_weight = old_weight[:,in_channels:]
940
+ zero_weight = torch.zeros_like(in_weight)
941
+ new_weight = torch.cat(
942
+ [in_weight] +
943
+ [zero_weight] * i +
944
+ [cond_weight] +
945
+ [zero_weight] * (num_condition_names - i - 1),
946
+ dim=1
947
+ )
948
+ pretrained_model_name_or_path_or_dict[unet_conv_in_lora_A_name] = new_weight
949
+
950
+ super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **_kwargs)
951
+
952
+ if adapter_name_ori is not None:
953
+ break
954
+
955
+ unet.activate_extra_condition_adapters()
956
+
957
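+ # A hedged usage sketch for `load_lora_weights` (the repository paths below are
+ # placeholders): each checkpoint is mapped, in order, onto one entry of
+ # `unet.extra_condition_names`, and its conv_in lora_A weight is zero-padded so the
+ # adapter only reads its own condition channels; a single path is broadcast to every
+ # condition name.
+ # >>> pipe.load_lora_weights(
+ # ...     ["path/to/depth_control_lora", "path/to/tile_control_lora"],
+ # ...     weight_name="pytorch_lora_weights.safetensors",
+ # ... )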
+ @torch.no_grad()
958
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
959
+ def __call__(
960
+ self,
961
+ prompt: Union[str, List[str]] = None,
962
+ image: PipelineImageInput = None,
963
+ height: Optional[int] = None,
964
+ width: Optional[int] = None,
965
+ num_inference_steps: int = 50,
966
+ timesteps: List[int] = None,
967
+ sigmas: List[float] = None,
968
+ guidance_scale: float = 7.5,
969
+ negative_prompt: Optional[Union[str, List[str]]] = None,
970
+ num_images_per_prompt: Optional[int] = 1,
971
+ eta: float = 0.0,
972
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
973
+ latents: Optional[torch.Tensor] = None,
974
+ prompt_embeds: Optional[torch.Tensor] = None,
975
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
976
+ ip_adapter_image: Optional[PipelineImageInput] = None,
977
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
978
+ output_type: Optional[str] = "pil",
979
+ return_dict: bool = True,
980
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
981
+ extra_condition_scale: Union[float, List[float]] = 1.0,
982
+ control_guidance_start: Union[float, List[float]] = 0.0,
983
+ control_guidance_end: Union[float, List[float]] = 1.0,
984
+ clip_skip: Optional[int] = None,
985
+ callback_on_step_end: Optional[
986
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
987
+ ] = None,
988
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
989
+ **kwargs,
990
+ ):
991
+ r"""
992
+ The call function to the pipeline for generation.
993
+
994
+ Args:
995
+ prompt (`str` or `List[str]`, *optional*):
996
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
997
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
998
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
999
+ The extra input condition that provides guidance to the `unet` for generation after being encoded by the `vae`. If the type is
1000
+ specified as `torch.Tensor`, its `vae` latent representation is passed to UNet. `PIL.Image.Image` can also be accepted
1001
+ as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
1002
+ width are passed, `image` is resized accordingly. If multiple extra conditions are specified in `unet`,
1003
+ images must be passed as a list such that each element of the list can be correctly batched for input
1004
+ to `unet`. When `prompt` is a list, and if a list of images is passed for `unet`, each will be paired with each prompt
1005
+ in the `prompt` list. This also applies to multiple extra conditions, where a list of image lists can be
1006
+ passed to batch for each prompt and each extra condition.
1007
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1008
+ The height in pixels of the generated image.
1009
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1010
+ The width in pixels of the generated image.
1011
+ num_inference_steps (`int`, *optional*, defaults to 50):
1012
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1013
+ expense of slower inference.
1014
+ timesteps (`List[int]`, *optional*):
1015
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1016
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1017
+ passed will be used. Must be in descending order.
1018
+ sigmas (`List[float]`, *optional*):
1019
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1020
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1021
+ will be used.
1022
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1023
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1024
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1025
+ negative_prompt (`str` or `List[str]`, *optional*):
1026
+ The prompt or prompts to guide what not to include in image generation. If not defined, you need to
1027
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
1028
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1029
+ The number of images to generate per prompt.
1030
+ eta (`float`, *optional*, defaults to 0.0):
1031
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1032
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1033
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1034
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1035
+ generation deterministic.
1036
+ latents (`torch.Tensor`, *optional*):
1037
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1038
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1039
+ tensor is generated by sampling using the supplied random `generator`.
1040
+ prompt_embeds (`torch.Tensor`, *optional*):
1041
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1042
+ provided, text embeddings are generated from the `prompt` input argument.
1043
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1044
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1045
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1046
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1047
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1048
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1049
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1050
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1051
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1052
+ output_type (`str`, *optional*, defaults to `"pil"`):
1053
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1054
+ return_dict (`bool`, *optional*, defaults to `True`):
1055
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1056
+ plain tuple.
1057
+ callback (`Callable`, *optional*):
1058
+ *Deprecated*, use `callback_on_step_end` instead. A function called every `callback_steps` steps during
1059
+ inference with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
1060
+ callback_steps (`int`, *optional*, defaults to 1):
1061
+ *Deprecated*, use `callback_on_step_end` instead. The frequency at which the `callback` function is
1062
+ called. If not specified, the callback is called at every step.
1063
+ cross_attention_kwargs (`dict`, *optional*):
1064
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1065
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1066
+ extra_condition_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1067
+ The control LoRA scale of the `unet`. If multiple extra conditions are specified in `unet`, you can set
1068
+ the corresponding scale as a list.
1069
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1070
+ The percentage of total steps at which the extra condition starts applying.
1071
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1072
+ The percentage of total steps at which the extra condition stops applying.
1073
+ clip_skip (`int`, *optional*):
1074
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1075
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1076
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1077
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1078
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1079
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1080
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1081
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1082
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1083
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1084
+ `._callback_tensor_inputs` attribute of your pipeline class.
1085
+
1086
+ Examples:
1087
+
1088
+ Returns:
1089
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1090
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1091
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1092
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1093
+ "not-safe-for-work" (nsfw) content.
1094
+ """
1095
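+ # A minimal, hedged usage sketch of this call (checkpoint ids and file names are
+ # placeholders; the decorator above injects the authoritative EXAMPLE_DOC_STRING):
+ # >>> from PIL import Image
+ # >>> cond = Image.open("condition.png")
+ # >>> result = pipe(
+ # ...     prompt="a photo of a cat",
+ # ...     image=cond,                      # encoded into VAE latents, not a ControlNet hint
+ # ...     num_inference_steps=30,
+ # ...     guidance_scale=7.5,
+ # ...     extra_condition_scale=1.0,
+ # ... )
+ # >>> result.images[0].save("out.png")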
+
1096
+ callback = kwargs.pop("callback", None)
1097
+ callback_steps = kwargs.pop("callback_steps", None)
1098
+
1099
+ if callback is not None:
1100
+ deprecate(
1101
+ "callback",
1102
+ "1.0.0",
1103
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1104
+ )
1105
+ if callback_steps is not None:
1106
+ deprecate(
1107
+ "callback_steps",
1108
+ "1.0.0",
1109
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1110
+ )
1111
+
1112
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1113
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1114
+
1115
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
1116
+ num_extra_conditions = len(unet.extra_condition_names)
1117
+
1118
+ # align format for control guidance
1119
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1120
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1121
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1122
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1123
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1124
+ mult = num_extra_conditions
1125
+ control_guidance_start, control_guidance_end = (
1126
+ mult * [control_guidance_start],
1127
+ mult * [control_guidance_end],
1128
+ )
1129
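+ # Worked example of the broadcasting above (illustrative values): with two extra
+ # conditions and scalars control_guidance_start=0.0, control_guidance_end=0.8, the
+ # lists become [0.0, 0.0] and [0.8, 0.8]; a scalar paired with a list is repeated to
+ # that list's length.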
+
1130
+ # 1. Check inputs. Raise error if not correct
1131
+ self.check_inputs(
1132
+ prompt,
1133
+ image,
1134
+ callback_steps,
1135
+ negative_prompt,
1136
+ prompt_embeds,
1137
+ negative_prompt_embeds,
1138
+ ip_adapter_image,
1139
+ ip_adapter_image_embeds,
1140
+ extra_condition_scale,
1141
+ control_guidance_start,
1142
+ control_guidance_end,
1143
+ callback_on_step_end_tensor_inputs,
1144
+ )
1145
+
1146
+ self._guidance_scale = guidance_scale
1147
+ self._clip_skip = clip_skip
1148
+ self._cross_attention_kwargs = cross_attention_kwargs
1149
+
1150
+ # 2. Define call parameters
1151
+ if prompt is not None and isinstance(prompt, str):
1152
+ batch_size = 1
1153
+ elif prompt is not None and isinstance(prompt, list):
1154
+ batch_size = len(prompt)
1155
+ else:
1156
+ batch_size = prompt_embeds.shape[0]
1157
+
1158
+ device = self._execution_device
1159
+
1160
+ if num_extra_conditions > 1 and isinstance(extra_condition_scale, float):
1161
+ extra_condition_scale = [extra_condition_scale] * num_extra_conditions
1162
+
1163
+ # 3. Encode input prompt
1164
+ text_encoder_lora_scale = (
1165
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1166
+ )
1167
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1168
+ prompt,
1169
+ device,
1170
+ num_images_per_prompt,
1171
+ self.do_classifier_free_guidance,
1172
+ negative_prompt,
1173
+ prompt_embeds=prompt_embeds,
1174
+ negative_prompt_embeds=negative_prompt_embeds,
1175
+ lora_scale=text_encoder_lora_scale,
1176
+ clip_skip=self.clip_skip,
1177
+ )
1178
+ # For classifier free guidance, we need to do two forward passes.
1179
+ # Here we concatenate the unconditional and text embeddings into a single batch
1180
+ # to avoid doing two forward passes
1181
+ if self.do_classifier_free_guidance:
1182
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1183
+
1184
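+ # Illustrative shapes for the concatenation above (assumes the SD 1.5 text encoder with
+ # hidden size 768 and 77 tokens): with CFG, batch_size == num_images_per_prompt == 1,
+ # prompt_embeds is torch.Size([2, 77, 768]); row 0 holds the negative embedding and
+ # row 1 the conditional one, so a single UNet forward pass covers both branches.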
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1185
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1186
+ ip_adapter_image,
1187
+ ip_adapter_image_embeds,
1188
+ device,
1189
+ batch_size * num_images_per_prompt,
1190
+ self.do_classifier_free_guidance,
1191
+ )
1192
+
1193
+ # 4. Prepare image
1194
+ if num_extra_conditions == 1:
1195
+ image = self.prepare_image(
1196
+ image=image,
1197
+ width=width,
1198
+ height=height,
1199
+ batch_size=batch_size * num_images_per_prompt,
1200
+ num_images_per_prompt=num_images_per_prompt,
1201
+ device=device,
1202
+ dtype=unet.dtype,
1203
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1204
+ )
1205
+ height, width = image.shape[-2:]
1206
+ image = (
1207
+ self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
1208
+ )
1209
+ elif num_extra_conditions > 1:
1210
+ images = []
1211
+
1212
+ # Nested lists as extra condition
1213
+ if isinstance(image[0], list):
1214
+ # Transpose the nested image list
1215
+ image = [list(t) for t in zip(*image)]
1216
+
1217
+ for image_ in image:
1218
+ image_ = self.prepare_image(
1219
+ image=image_,
1220
+ width=width,
1221
+ height=height,
1222
+ batch_size=batch_size * num_images_per_prompt,
1223
+ num_images_per_prompt=num_images_per_prompt,
1224
+ device=device,
1225
+ dtype=unet.dtype,
1226
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1227
+ )
1228
+
1229
+ images.append(image_)
1230
+
1231
+ image = images
1232
+ height, width = image[0].shape[-2:]
1233
+ image = [
1234
+ self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
1235
+ for image in images
1236
+ ]
1237
+ else:
1238
+ assert False
1239
+
1240
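+ # Hedged sketch of the condition encoding above (illustrative numbers; assumes
+ # vae_scale_factor == 8, and `cond_pixels` is a placeholder for the tensor returned by
+ # prepare_image): a 512x512 condition image prepared with CFG for one prompt ends up
+ # as a latent of shape (2, 4, 64, 64).
+ # >>> cond_latents = pipe.vae.encode(cond_pixels).latent_dist.mode() * pipe.vae.config.scaling_factor
+ # >>> cond_latents.shape
+ # torch.Size([2, 4, 64, 64])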
+ # 5. Prepare timesteps
1241
+ timesteps, num_inference_steps = retrieve_timesteps(
1242
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1243
+ )
1244
+ self._num_timesteps = len(timesteps)
1245
+
1246
+ # 6. Prepare latent variables
1247
+ num_channels_latents = self.unet.config.in_channels
1248
+ latents = self.prepare_latents(
1249
+ batch_size * num_images_per_prompt,
1250
+ num_channels_latents,
1251
+ height,
1252
+ width,
1253
+ prompt_embeds.dtype,
1254
+ device,
1255
+ generator,
1256
+ latents,
1257
+ )
1258
+
1259
+ # 6.5 Optionally get Guidance Scale Embedding
1260
+ timestep_cond = None
1261
+ if self.unet.config.time_cond_proj_dim is not None:
1262
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1263
+ timestep_cond = self.get_guidance_scale_embedding(
1264
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1265
+ ).to(device=device, dtype=latents.dtype)
1266
+
1267
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1269
+
1270
+ # 7.1 Add image embeds for IP-Adapter
1271
+ added_cond_kwargs = (
1272
+ {"image_embeds": image_embeds}
1273
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1274
+ else None
1275
+ )
1276
+
1277
+ # 7.2 Create a list stating which extra_conditions to keep at each step
1278
+ extra_condition_keep = []
1279
+ for i in range(len(timesteps)):
1280
+ keeps = [
1281
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1282
+ for s, e in zip(control_guidance_start, control_guidance_end)
1283
+ ]
1284
+ extra_condition_keep.append(keeps[0] if num_extra_conditions == 1 else keeps)
1285
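+ # Worked example of the keep schedule above (illustrative values): with 50 timesteps,
+ # control_guidance_start=0.0 and control_guidance_end=0.5, keeps[i] is 1.0 for steps
+ # 0-24 and 0.0 afterwards; with a single extra condition each extra_condition_keep[i]
+ # is a float, otherwise a per-condition list.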
+
1286
+ # 8. Denoising loop
1287
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1288
+ is_unet_compiled = is_compiled_module(self.unet)
1289
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1290
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1291
+ for i, t in enumerate(timesteps):
1292
+ # Relevant thread:
1293
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1294
+ if is_unet_compiled and is_torch_higher_equal_2_1:
1295
+ torch._inductor.cudagraph_mark_step_begin()
1296
+ # expand the latents if we are doing classifier free guidance
1297
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1298
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1299
+
1300
+ if isinstance(extra_condition_keep[i], list):
1301
+ cond_scale = [c * s for c, s in zip(extra_condition_scale, extra_condition_keep[i])]
1302
+ else:
1303
+ extra_cond_scale = extra_condition_scale
1304
+ if isinstance(extra_cond_scale, list):
1305
+ extra_cond_scale = extra_cond_scale[0]
1306
+ cond_scale = extra_cond_scale * extra_condition_keep[i]
1307
+
1308
+ self.unet.set_extra_condition_scale(cond_scale)
1309
+
1310
+ # predict the noise residual
1311
+ noise_pred = self.unet(
1312
+ latent_model_input,
1313
+ t,
1314
+ encoder_hidden_states=prompt_embeds,
1315
+ timestep_cond=timestep_cond,
1316
+ cross_attention_kwargs=self.cross_attention_kwargs,
1317
+ added_cond_kwargs=added_cond_kwargs,
1318
+ extra_conditions=image,
1319
+ return_dict=False,
1320
+ )[0]
1321
+
1322
+ # perform guidance
1323
+ if self.do_classifier_free_guidance:
1324
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1325
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1326
+
1327
+ # compute the previous noisy sample x_t -> x_t-1
1328
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1329
+
1330
+ if callback_on_step_end is not None:
1331
+ callback_kwargs = {}
1332
+ for k in callback_on_step_end_tensor_inputs:
1333
+ callback_kwargs[k] = locals()[k]
1334
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1335
+
1336
+ latents = callback_outputs.pop("latents", latents)
1337
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1338
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1339
+
1340
+ # call the callback, if provided
1341
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1342
+ progress_bar.update()
1343
+ if callback is not None and i % callback_steps == 0:
1344
+ step_idx = i // getattr(self.scheduler, "order", 1)
1345
+ callback(step_idx, t, latents)
1346
+
1347
+ self.unet.set_extra_condition_scale(1.0)
1348
+
1349
+ # If we do sequential model offloading, let's offload unet
1350
+ # manually for max memory savings
1351
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1352
+ self.unet.to("cpu")
1353
+ torch.cuda.empty_cache()
1354
+
1355
+ if output_type != "latent":
1356
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1357
+ 0
1358
+ ]
1359
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1360
+ else:
1361
+ image = latents
1362
+ has_nsfw_concept = None
1363
+
1364
+ if has_nsfw_concept is None:
1365
+ do_denormalize = [True] * image.shape[0]
1366
+ else:
1367
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1368
+
1369
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1370
+
1371
+ # Offload all models
1372
+ self.maybe_free_model_hooks()
1373
+
1374
+ if not return_dict:
1375
+ return (image, has_nsfw_concept)
1376
+
1377
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)