AbstractPhil committed on
Commit eb5263e · verified · 1 Parent(s): 99cb0de

Update README.md

Files changed (1): README.md (+299 -3)

README.md CHANGED

---
license: apache-2.0
---

This is a shunt adapter that takes in the T5-small and CLIP ViT-L/14 token streams simultaneously.

The T5-small stream is used as a conditioning factor for normalization and guidance.

There are many possible toggles and many variations in which this shunt can be used.

The only path hooked up here is the basic one meant for simple text-encoder guidance; its predicted delta is then shunted into the CLIP-L prompt embeddings (`clip_embeds`) as a test. The core operation is sketched below.
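For orientation, here is a minimal sketch of that core operation on dummy tensors. The shapes are assumptions matching SDXL, whose per-token prompt embedding is the 768-dim CLIP-L hidden state concatenated with the 1280-dim CLIP-G hidden state; the tensors here are stand-ins, not real pipeline outputs.

```python
import torch

# Stand-ins for illustration only.
pe = torch.randn(1, 77, 2048)     # shaped like pipe.encode_prompt(...) prompt embeddings
delta = torch.randn(1, 77, 768)   # shaped like the gated delta predicted by the shunt
strength = 0.5                    # blend knob; 0.0 leaves the prompt untouched

clipL, clipG = pe[..., :768], pe[..., 768:]             # split CLIP-L / CLIP-G
clipL_shifted = clipL + strength * delta                # shift only the CLIP-L half
pe_shifted = torch.cat([clipL_shifted, clipG], dim=-1)  # recombine for the pipeline
```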

---

Below is the adapter definition. The test script that wires it into the SDXL pipeline follows it.

```python
import safetensors.torch as st
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel


# ─────────────────────────────────────────────────────────────
# ░ Two-Stream Shunt Adapter
# ─────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Returns:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)
        log_sigma : (B, Lc, clip_dim) – log σ, always finite
        attn_t2c  : (B, heads, Lt, Lc)
        attn_c2t  : (B, heads, Lc, Lt)
        tau       : (heads, 1, 1) – per-head threshold param
        g_pred    : (B, 1) – guidance-scale prediction
        gate      : (B, Lc, 1) – per-token gate ∈ (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # projections into the shared bottleneck
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # cross-attention in both directions
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # head-wise τ
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # convolutional pocket residual (depth-wise)
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # fusion + output projections
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, state_dict, **kwargs):
        # strip the torch.compile "_orig_mod." prefix before applying
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        print("📣 SHUNT FORWARD CALLED")

        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # 1) project both streams into the bottleneck
        t5_b = self.proj_t5(t5_seq)        # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) cross-attention in both directions
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on the T5→CLIP stream
        x = t2c.transpose(1, 2)                   # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)  # (B, Lt, b)
        pocket = self.norm_res(t2c + x)           # (B, Lt, b)

        # 4) fuse the pocket average with the CLIP→T5 stream
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor = self.anchor_proj(h)     # (B, Lc, clip_dim)
        delta_mean = self.delta_proj(h)  # (B, Lc, clip_dim)
        log_sigma = self.logsig_proj(h)  # (B, Lc, clip_dim)
        gate = self.gate_proj(h)         # (B, Lc, 1)
        delta = delta_mean * gate        # (B, Lc, clip_dim)

        g_tok = self.guidance_proj(h).squeeze(-1)                 # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance  # (B, 1)

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
```
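As a quick sanity check of the output shapes, the untrained adapter can be exercised with random tensors; a minimal sketch, assuming the class above is in scope (the batch size and sequence lengths below are arbitrary, the feature dims match the defaults):

```python
import torch

# Minimal sketch: exercise the untrained adapter with random inputs.
adapter = TwoStreamShuntAdapter().eval()  # defaults: t5_dim=512, clip_dim=768, heads=8
t5_seq = torch.randn(2, 16, 512)          # (B, Lt, t5_dim)   e.g. t5-small hidden states
clip_seq = torch.randn(2, 77, 768)        # (B, Lc, clip_dim) e.g. CLIP-L hidden states

with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

print(delta.shape)     # torch.Size([2, 77, 768])
print(gate.shape)      # torch.Size([2, 77, 1])
print(g_pred.shape)    # torch.Size([2, 1])
print(attn_t2c.shape)  # torch.Size([2, 8, 16, 77])
```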

And the test script that monkey-patches `pipe.encode_prompt` and sweeps the shunt strength:

```python
# --- 1. load the SDXL pipeline -----------------------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# --- 2. load tiny-T5 & the shunt (fp32) --------------------------------
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")
shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file("/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors")
)

# --- 3. wrap encode_prompt once ----------------------------------------
orig_encode = pipe.encode_prompt

# Knobs the adapter is meant to expose; only `strength` is wired up in this test.
config = {
    "strength": 1.0,
    "gate_gamma": 1.0,
    "tau_scale": 1.0,
    "guidance_gain": 1.0,
    "guidance_bias": 0.0,
}

gen = torch.Generator(device="cuda").manual_seed(420)  # unused below; each call seeds its own generator

strength = 0  # overwritten inside the sweep loop


# The simple known-working version, kept for reference: dummy T5 text,
# shift only the CLIP-L half of the prompt embeddings.
def stable_encode_prompt_shunted(self, *args, **kw):
    pe, ne, pool, npool = orig_encode(*args, **kw)  # regular call

    # split: the first 768 dims are CLIP-L, the remaining 1280 are CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # build the T5 batch (CFG duplication is handled automatically because
    # encode_prompt has already concatenated negative & positive if needed)
    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz  # dummy text, only the hidden states matter here
    t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 512)

    # run the adapter in fp32
    delta = shunt(t5_seq.float(), clipL.float())[1]  # second output is Δ
    delta = delta * strength                         # strength knob
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool
# ------------------------------------------------------------------------


# The variant patched in below: T5 reads the *context* text passed as prompt_2.
def encode_prompt_shunted(self, *a, **k):
    # 1) run the normal encoder with "style" & "context" already split
    pe, ne, pool, npool = orig_encode(*a, **k)  # (B, 77, 2048)

    # 2) split CLIP-L / CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # 3) build T5 on the context text (it arrives as k["prompt_2"])
    t5_ids = t5_tok([k.get("prompt_2")], return_tensors="pt").input_ids.to(pe.device)
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()

    # 4) shunt → Δ (fp32, then cast back)
    Δ = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
    clipL_shift = clipL + Δ * strength

    # 5) concatenate back
    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool

pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))


PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5

# --- 4. generate: sweep the shunt strength up while easing CFG down -----
for i in range(0, 4):
    strength = base_strength + (i * 0.25)
    cfg = base_cfg - (i * 0.25)
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        guidance_scale=cfg,
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    img.save(f"woman_cfg_{int(cfg*100)}_{int(strength*100)}.png")

# Earlier single-run renders kept for reference. The same call was repeated with
# strength = 0, 0.25, 0.5 and 0.75, saving majestic_baseline.png, majestic_02.png,
# majestic_05.png and majestic_075.png:
#
# strength = 0.25
# img = pipe(
#     PROMPT,
#     negative_prompt=NEG,
#     num_inference_steps=STEPS,
#     generator=torch.Generator(device="cuda").manual_seed(420),
# ).images[0]
# img.save("majestic_02.png")
```
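
For a side-by-side comparison against an unshunted result, the original encoder can be restored temporarily; a minimal sketch reusing the names from the script above (the output filename here is arbitrary):

```python
# Minimal sketch: render one unshunted baseline, then re-apply the shunt.
pipe.encode_prompt = orig_encode  # undo the monkey-patch
baseline = pipe(
    PROMPT,
    prompt_2=PROMPT_2,
    negative_prompt=NEG,
    num_inference_steps=STEPS,
    guidance_scale=base_cfg,
    generator=torch.Generator(device="cuda").manual_seed(420),
).images[0]
baseline.save("woman_baseline.png")  # hypothetical filename
pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))  # re-apply the shunt
```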