chriswu25 committed (verified)
Commit b4060af · 1 parent: 3fcfa08

Update src/layers_cache.py

Files changed (1): src/layers_cache.py (+367, -367)
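The file is removed and re-added in full, but the only line whose content actually changes appears to be the default value of the `device` argument in `LoRALinearLayer.__init__` (line 18 on both sides):

-         device: Optional[Union[torch.device, str]] = None,
+         device: Optional[Union[torch.device, str]] = "cpu",

With the new default, a `LoRALinearLayer` constructed without an explicit `device` places its `down`/`up` linear weights on the CPU rather than on PyTorch's current default device.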
src/layers_cache.py CHANGED
@@ -1,368 +1,368 @@
(Removed: the previous 368-line version of src/layers_cache.py. It is line-for-line identical to the version added below, except that line 18 read `device: Optional[Union[torch.device, str]] = None,` instead of `= "cpu",`.)
+ import inspect
+ import math
+ from typing import Callable, List, Optional, Tuple, Union
+ from einops import rearrange
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from diffusers.models.attention_processor import Attention
+
+ class LoRALinearLayer(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         rank: int = 4,
+         network_alpha: Optional[float] = None,
+         device: Optional[Union[torch.device, str]] = "cpu",
+         dtype: Optional[torch.dtype] = None,
+         cond_width=512,
+         cond_height=512,
+         number=0,
+         n_loras=1
+     ):
+         super().__init__()
+         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+         self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+         self.network_alpha = network_alpha
+         self.rank = rank
+         self.out_features = out_features
+         self.in_features = in_features
+
+         nn.init.normal_(self.down.weight, std=1 / rank)
+         nn.init.zeros_(self.up.weight)
+
+         self.cond_height = cond_height
+         self.cond_width = cond_width
+         self.number = number
+         self.n_loras = n_loras
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         orig_dtype = hidden_states.dtype
+         dtype = self.down.weight.dtype
+
+         ####
+         batch_size = hidden_states.shape[0]
+         cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
+         block_size = hidden_states.shape[1] - cond_size * self.n_loras
+         shape = (batch_size, hidden_states.shape[1], 3072)
+         mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
+         mask[:, :block_size+self.number*cond_size, :] = 0
+         mask[:, block_size+(self.number+1)*cond_size:, :] = 0
+         hidden_states = mask * hidden_states
+         ####
+
+         down_hidden_states = self.down(hidden_states.to(dtype))
+         up_hidden_states = self.up(down_hidden_states)
+
+         if self.network_alpha is not None:
+             up_hidden_states *= self.network_alpha / self.rank
+
+         return up_hidden_states.to(orig_dtype)
+
+
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
+     def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
+         super().__init__()
+         # Initialize a list to store the LoRA layers
+         self.n_loras = n_loras
+         self.cond_width = cond_width
+         self.cond_height = cond_height
+
+         self.q_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.k_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.v_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.lora_weights = lora_weights
+         self.bank_attn = None
+         self.bank_kv = []
+
+
+     def __call__(self,
+                  attn: Attention,
+                  hidden_states: torch.FloatTensor,
+                  encoder_hidden_states: torch.FloatTensor = None,
+                  attention_mask: Optional[torch.FloatTensor] = None,
+                  image_rotary_emb: Optional[torch.Tensor] = None,
+                  use_cond = False
+                  ) -> torch.FloatTensor:
+
+         batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         scaled_seq_len = hidden_states.shape[1]
+         cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
+         block_size = scaled_seq_len - cond_size * self.n_loras
+         scaled_cond_size = cond_size
+         scaled_block_size = block_size
+
+         if len(self.bank_kv)== 0:
+             cache = True
+         else:
+             cache = False
+
+         if cache:
+             query = attn.to_q(hidden_states)
+             key = attn.to_k(hidden_states)
+             value = attn.to_v(hidden_states)
+             for i in range(self.n_loras):
+                 query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
+                 key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
+                 value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
+
+             inner_dim = key.shape[-1]
+             head_dim = inner_dim // attn.heads
+
+             query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+             self.bank_kv.append(key[:, :, scaled_block_size:, :])
+             self.bank_kv.append(value[:, :, scaled_block_size:, :])
+
+             if attn.norm_q is not None:
+                 query = attn.norm_q(query)
+             if attn.norm_k is not None:
+                 key = attn.norm_k(key)
+
+             if image_rotary_emb is not None:
+                 from diffusers.models.embeddings import apply_rotary_emb
+                 query = apply_rotary_emb(query, image_rotary_emb)
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+             num_cond_blocks = self.n_loras
+             mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
+             mask[ :scaled_block_size, :] = 0 # First block_size row
+             for i in range(num_cond_blocks):
+                 start = i * scaled_cond_size + scaled_block_size
+                 end = (i + 1) * scaled_cond_size + scaled_block_size
+                 mask[start:end, start:end] = 0 # Diagonal blocks
+             mask = mask * -1e10
+             mask = mask.to(query.dtype)
+
+             hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
+             self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
+
+         else:
+             query = attn.to_q(hidden_states)
+             key = attn.to_k(hidden_states)
+             value = attn.to_v(hidden_states)
+
+             inner_dim = query.shape[-1]
+             head_dim = inner_dim // attn.heads
+
+             query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+             key = torch.concat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2)
+             value = torch.concat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2)
+
+             if attn.norm_q is not None:
+                 query = attn.norm_q(query)
+             if attn.norm_k is not None:
+                 key = attn.norm_k(key)
+
+             if image_rotary_emb is not None:
+                 from diffusers.models.embeddings import apply_rotary_emb
+                 query = apply_rotary_emb(query, image_rotary_emb)
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+             query = query[:, :, :scaled_block_size, :]
+
+             hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
+             hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         cond_hidden_states = hidden_states[:, block_size:,:]
+         hidden_states = hidden_states[:, : block_size,:]
+
+         return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
+
+
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
+     def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
+         super().__init__()
+
+         # Initialize a list to store the LoRA layers
+         self.n_loras = n_loras
+         self.cond_width = cond_width
+         self.cond_height = cond_height
+         self.q_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.k_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.v_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.proj_loras = nn.ModuleList([
+             LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
+             for i in range(n_loras)
+         ])
+         self.lora_weights = lora_weights
+         self.bank_attn = None
+         self.bank_kv = []
+
+
+     def __call__(self,
+                  attn: Attention,
+                  hidden_states: torch.FloatTensor,
+                  encoder_hidden_states: torch.FloatTensor = None,
+                  attention_mask: Optional[torch.FloatTensor] = None,
+                  image_rotary_emb: Optional[torch.Tensor] = None,
+                  use_cond=False,
+                  ) -> torch.FloatTensor:
+
+         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
+         block_size = hidden_states.shape[1] - cond_size * self.n_loras
+         scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1]
+         scaled_cond_size = cond_size
+         scaled_block_size = scaled_seq_len - scaled_cond_size * self.n_loras
+
+         # `context` projections.
+         inner_dim = 3072
+         head_dim = inner_dim // attn.heads
+         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+
+         if attn.norm_added_q is not None:
+             encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+         if attn.norm_added_k is not None:
+             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+         if len(self.bank_kv)== 0:
+             cache = True
+         else:
+             cache = False
+
+         if cache:
+
+             query = attn.to_q(hidden_states)
+             key = attn.to_k(hidden_states)
+             value = attn.to_v(hidden_states)
+             for i in range(self.n_loras):
+                 query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
+                 key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
+                 value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
+
+             inner_dim = key.shape[-1]
+             head_dim = inner_dim // attn.heads
+             query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+
+             self.bank_kv.append(key[:, :, block_size:, :])
+             self.bank_kv.append(value[:, :, block_size:, :])
+
+             if attn.norm_q is not None:
+                 query = attn.norm_q(query)
+             if attn.norm_k is not None:
+                 key = attn.norm_k(key)
+
+             # attention
+             query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+             if image_rotary_emb is not None:
+                 from diffusers.models.embeddings import apply_rotary_emb
+                 query = apply_rotary_emb(query, image_rotary_emb)
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+             num_cond_blocks = self.n_loras
+             mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
+             mask[ :scaled_block_size, :] = 0 # First block_size row
+             for i in range(num_cond_blocks):
+                 start = i * scaled_cond_size + scaled_block_size
+                 end = (i + 1) * scaled_cond_size + scaled_block_size
+                 mask[start:end, start:end] = 0 # Diagonal blocks
+             mask = mask * -1e10
+             mask = mask.to(query.dtype)
+
+             hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
+             self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
+
+         else:
+             query = attn.to_q(hidden_states)
+             key = attn.to_k(hidden_states)
+             value = attn.to_v(hidden_states)
+
+             inner_dim = query.shape[-1]
+             head_dim = inner_dim // attn.heads
+
+             query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+             value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+             key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2)
+             value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)
+
+             if attn.norm_q is not None:
+                 query = attn.norm_q(query)
+             if attn.norm_k is not None:
+                 key = attn.norm_k(key)
+
+             # attention
+             query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+             if image_rotary_emb is not None:
+                 from diffusers.models.embeddings import apply_rotary_emb
+                 query = apply_rotary_emb(query, image_rotary_emb)
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+             query = query[:, :, :scaled_block_size, :]
+
+             hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
+             hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         encoder_hidden_states, hidden_states = (
+             hidden_states[:, : encoder_hidden_states.shape[1]],
+             hidden_states[:, encoder_hidden_states.shape[1] :],
+         )
+
+         # Linear projection (with LoRA weight applied to each proj layer)
+         hidden_states = attn.to_out[0](hidden_states)
+         for i in range(self.n_loras):
+             hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+         cond_hidden_states = hidden_states[:, block_size:,:]
+         hidden_states = hidden_states[:, :block_size,:]
+
+         return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
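
For readers who want to sanity-check the new default, here is a minimal, illustrative sketch (not part of the commit). It assumes the module is importable as `src.layers_cache` from the repository root; the 3072 feature width and the 512x512 condition size come from the hard-coded values and defaults in the file above, and everything else is made up for the example.

import torch
from src.layers_cache import LoRALinearLayer  # import path assumed from this repo's layout

# cond_width = cond_height = 512 gives cond_size = 512//8 * 512//8 * 16 // 64 = 1024 tokens;
# the layer zeroes every token outside its own 1024-token condition block before the LoRA projection.
layer = LoRALinearLayer(3072, 3072, rank=4, network_alpha=4.0, number=0, n_loras=1)
print(layer.down.weight.device)   # cpu -- the new default when no device is passed

hidden = torch.randn(1, 64 + 1024, 3072)   # 64 "image" tokens followed by one condition block
out = layer(hidden)                        # all zeros here, since self.up is zero-initialized
print(out.shape)                           # torch.Size([1, 1088, 3072])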