Any-to-Any

This is a shunt that takes in the outputs of T5-small and the CLIP ViT-L/14 text encoder (CLIP-L) simultaneously.

T5-small is used as a conditioning signal for normalization and guidance.

The shunt exposes many possible toggles, and there are many ways it can be used.

The only mode I have hooked up is the basic one, meant for simple text-encoder guidance. I also tried shunting it directly into clip_embeds as a test, but the results fell apart.

The approach that worked with diffusers without a headache was overriding the pipeline's `encode_prompt` with a monkey patch.

Drag and drop into Colab and generate some SDXL images with it. Use two cells: the model-loading code goes in the first cell, above the generation cell.

Fiddle with the taps and settings to add or reduce the T5-small guidance applied to your CLIP-L embeddings.

import safetensors.torch as st
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# ─────────────────────────────────────────────────────────────
# ░ Two-Stream Shunt Adapter
# ─────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Both streams are linearly projected into a shared ``bottleneck`` space and
    attended in both directions (T5→CLIP and CLIP→T5).  The T5→CLIP result is
    refined by a depth-wise convolutional "pocket", mean-pooled over T5
    tokens, fused with the CLIP→T5 result, and every output head is computed
    from that fused per-CLIP-token feature.

    Args:
        t5_dim: channel width of the T5 hidden states.
        clip_dim: channel width of the CLIP hidden states (and of the
            anchor / delta / log_sigma outputs).
        bottleneck: shared attention width; must be divisible by ``heads``.
        heads: number of attention heads in both cross-attention blocks.
        tau_init: initial value of every entry of the ``tau`` parameter.
        max_guidance: upper bound that scales the guidance prediction.

    ``forward`` returns, in order:
        anchor     : (B, Lc, clip_dim)
        delta      : (B, Lc, clip_dim)   – delta_proj output multiplied by gate
        log_sigma  : (B, Lc, clip_dim)   – log σ, raw (unclamped) linear output
        attn_t2c   : (B, heads, Lt, Lc)  – per-head T5→CLIP attention weights
        attn_c2t   : (B, heads, Lc, Lt)  – per-head CLIP→T5 attention weights
        tau        : (heads, 1, 1)       – learned per-head parameter, returned
                                           as-is (not used inside forward)
        g_pred     : (B, 1)              – guidance prediction in
                                           [0, max_guidance]
        gate       : (B, Lc, 1)          – per-token sigmoid gate in [0, 1]
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # projections into the shared bottleneck space
        self.proj_t5   = nn.Linear(t5_dim,   bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # cross-attention in both directions
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # head-wise τ (returned by forward, not consumed by it)
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # convolutional pocket residual (depth-wise: groups == channels)
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # fusion of [pooled pocket, CLIP→T5] back to the bottleneck width
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, args, **kwargs):
        """Load weights, tolerating checkpoints saved from a compiled model.

        ``torch.compile`` wraps the module so saved keys contain an
        ``_orig_mod.`` segment; it is stripped from every key before
        delegating to ``nn.Module.load_state_dict``.

        Args:
            args: the state dict (name -> tensor). The parameter keeps the
                name ``args`` for backward compatibility with existing callers.
            **kwargs: forwarded to ``nn.Module.load_state_dict``
                (e.g. ``strict``).

        Returns:
            The ``missing_keys`` / ``unexpected_keys`` named tuple produced by
            ``nn.Module.load_state_dict`` (the original override dropped it).
        """
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in args.items()}
        return super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        """Run the adapter.

        Args:
            t5_seq: (B, Lt, t5_dim) T5 hidden states.
            clip_seq: (B, Lc, clip_dim) CLIP hidden states; same B as t5_seq.

        Returns:
            The 8-tuple documented on the class.
        """
        print("📣  SHUNT FORWARD CALLED")

        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # 1) project into bottleneck
        t5_b   = self.proj_t5(t5_seq)      # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) cross-attention; per-head weights are kept (not averaged)
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on T5→CLIP (Conv1d wants channels first)
        x = t2c.transpose(1, 2)                    # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)   # (B, Lt, b)
        pocket = self.norm_res(t2c + x)            # (B, Lt, b)

        # 4) fuse pocket average with C2T. The mean runs over all Lt tokens;
        #    no padding mask is applied (NOTE(review): padded T5 tokens, if
        #    any, are included in the average).
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor     = self.anchor_proj(h)                       # (B, Lc, clip_dim)
        delta_mean = self.delta_proj(h)                        # (B, Lc, clip_dim)
        log_sigma  = self.logsig_proj(h)                       # (B, Lc, clip_dim)
        gate       = self.gate_proj(h)                         # (B, Lc, 1)
        delta      = delta_mean * gate                         # (B, Lc, clip_dim)

        # per-token sigmoid in [0, 1], averaged and rescaled to [0, max_guidance]
        g_tok  = self.guidance_proj(h).squeeze(-1)             # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate

# --- 1. load pipeline -------------------------------------------------
# SDXL base in half precision, moved to the GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# --- 2. load tiny-T5 & shunt (fp32) -----------------------------------
# The T5-small encoder and the shunt are kept in fp32 on the GPU.
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")

shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file(
        "/content/drive/MyDrive/t5-clip-l-shunts/"
        "vitl14_t5small_shunt_vanilla_final.safetensors"
    )
)

# --- 3. wrap encode_prompt once ---------------------------------------
# Keep a handle to the unpatched, already-bound encoder so a replacement
# encode_prompt can delegate to it.
orig_encode = pipe.encode_prompt

# NOTE(review): these knobs are not referenced anywhere else in this script.
config = dict(
    strength=1.0,
    gate_gamma=1.0,
    tau_scale=1.0,
    guidance_gain=1.0,
    guidance_bias=0.0,
)


# NOTE(review): not referenced elsewhere in this script.
gen = torch.Generator(device="cuda").manual_seed(420)

Place the code below in a separate cell so the models above are not reloaded on every run.

# Shift strength read (as a module-level global) by the encode_prompt
# wrappers at call time; 0 leaves the CLIP-L embeddings unchanged.
strength = 0

# Alternate wrapper kept for reference. It is NOT the one installed on
# `pipe` in this script. It feeds T5 a dummy "tmp" text, so the shift it
# produces does not depend on the user's prompt.
def stable_encode_prompt_shunted(self, *args, **kw):
    """Shift the CLIP-L part of the positive prompt embeddings via the shunt.

    Args:
        self: the pipeline this function is bound to (unused; ``orig_encode``
            is already bound).
        *args, **kw: forwarded unchanged to ``orig_encode``.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled, negative_pooled)``
        as returned by ``orig_encode``, with only ``prompt_embeds`` modified.
    """
    pe, ne, pool, npool = orig_encode(*args, **kw)   # regular call

    # split: first 768 dims are CLIP-L, rest 1280 are CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # One dummy T5 input per row of the *positive* embeddings only:
    # encode_prompt returns positive (pe) and negative (ne) embeddings
    # separately, and ne is passed through unchanged.
    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz        # dummy text; the prompt itself is not used

    # Inference only: no autograd graph for T5 or the shunt, even when this
    # is called outside the pipeline's own no-grad context.
    with torch.no_grad():
        # follow the embeddings' device instead of hard-coding "cuda"
        t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to(pe.device)
        t5_seq = t5_mod(t5_ids).last_hidden_state        # (B, L, 512)
        # run adapter in fp32; output index 1 is the gated delta
        delta = shunt(t5_seq.float(), clipL.float())[1]

    delta = delta * strength                              # << your strength knob
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool
#-----------------------------------------------------------------------------------------

def encode_prompt_shunted(self, *a, **k):
    """Replacement for ``pipe.encode_prompt`` that shifts CLIP-L embeddings.

    Calls the original (unpatched) ``encode_prompt`` with the same
    arguments. It then runs T5-small on the *context* text and feeds T5 plus
    the CLIP-L slice of the positive prompt embeddings through ``shunt``.
    Finally it adds the returned delta, scaled by the module-level
    ``strength``, back onto CLIP-L.

    The context text is ``prompt_2`` when given, otherwise ``prompt``, which
    mirrors the fallback used by diffusers' SDXL ``encode_prompt``. Both may
    be given as keywords or as the first two positional arguments.

    Args:
        self: the pipeline this function is bound to (unused; ``orig_encode``
            is already bound).
        *a, **k: forwarded unchanged to ``orig_encode``.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled, negative_pooled)``
        as returned by ``orig_encode``; only ``prompt_embeds`` is modified.
        If no prompt text is available (e.g. only precomputed embeddings
        were passed), the original outputs are returned unchanged.

    Raises:
        ValueError: if the embedding batch is not a multiple of the number
            of context texts.
    """
    # 1) run the normal encoder
    pe, ne, pool, npool = orig_encode(*a, **k)          # pe: (B, 77, 2048)

    # 2) pick the T5 context text: prompt_2, falling back to prompt
    prompt = k.get("prompt", a[0] if len(a) > 0 else None)
    prompt_2 = k.get("prompt_2", a[1] if len(a) > 1 else None)
    context = prompt_2 or prompt
    if context is None:
        return pe, ne, pool, npool
    texts = [context] if isinstance(context, str) else list(context)
    if not texts:
        return pe, ne, pool, npool

    # 3) split CLIP-L / CLIP-G (first 768 channels are CLIP-L)
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # 4) T5 on the context text. The embedding batch may be larger than the
    #    number of texts (num_images_per_prompt), so T5 rows are repeated to
    #    match; the shunt's attention needs equal batch sizes.
    bsz = clipL.shape[0]
    if bsz % len(texts):
        raise ValueError(
            f"prompt embeddings batch ({bsz}) is not a multiple of the "
            f"number of context texts ({len(texts)})"
        )
    with torch.no_grad():
        enc = t5_tok(texts, padding=True, return_tensors="pt").to(pe.device)
        t5_seq = t5_mod(**enc).last_hidden_state.float()
        t5_seq = t5_seq.repeat_interleave(bsz // len(texts), dim=0)
        # 5) shunt → delta in fp32 (output index 1 is the gated delta)
        delta = shunt(t5_seq, clipL.float())[1]

    # 6) apply the shift in fp32, cast back, and re-assemble
    clipL_shift = (clipL.float() + delta * strength).to(clipL.dtype)
    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool

# Install the wrapper as a bound method on this pipeline instance.
pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))




PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5


# Sweep: each step raises the shunt strength by 0.25 and lowers CFG by 0.25.
# Every image is re-seeded with the same seed, so only those two knobs vary.
for i in range(4):
    # `strength` is the module-level global the encode_prompt wrapper reads
    strength = base_strength + i * 0.25
    cfg = base_cfg - i * 0.25
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        # The SDXL pipeline's CFG argument is `guidance_scale`. Passing
        # `cfg_scale` never set it, so the CFG sweep had no effect.
        guidance_scale=cfg,
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    # round() avoids truncation surprises from float products in filenames
    img.save(f"woman_cfg_{round(cfg * 100)}_{round(strength * 100)}.png")

# --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG, 
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_baseline.png")#
#

#strength = 0.25
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG, 
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_02.png")#

#strength = 0.5
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG, 
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_05.png")#

#strength = 0.75
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG, 
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_075.png")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for AbstractPhil/t5-vit-14-v1

Base model

google-t5/t5-small
Finetuned
(1952)
this model

Dataset used to train AbstractPhil/t5-vit-14-v1