# ip_adapter/resampler.py -- revision fd38570 ("Update ip_adapter/resampler.py"), uploaded by primerz.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP whose output already includes the skip connection.

    ``forward`` computes ``x + fc2(GELU(fc1(LayerNorm(x))))``.

    Args:
        dim: Size of the last dimension of the input and of the output.
        mult: Hidden-width multiplier; the hidden size is ``int(dim * mult)``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        # Overwrite the default weight init of both projections with
        # Xavier-uniform; the biases keep nn.Linear's default init.
        for proj in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, x):
        """Return ``x`` plus the MLP applied to the layer-normalised ``x``."""
        hidden = self.fc1(self.norm(x))
        hidden = self.act(hidden)
        return x + self.fc2(hidden)
def reshape_tensor(x, heads):
    """Split the last dimension of ``x`` into attention heads.

    Args:
        x: Tensor of shape ``(batch, length, width)``.
        heads: Number of heads; must divide ``width`` evenly.

    Returns:
        A view of ``x`` with shape ``(batch, heads, length, width // heads)``.

    Raises:
        RuntimeError: If ``width`` is not divisible by ``heads`` (raised by
            ``Tensor.view`` because the element counts do not match).
    """
    bs, length, width = x.shape
    # Spell out the per-head size instead of passing -1: PyTorch cannot infer
    # a -1 dimension for a tensor with zero elements, so an empty batch (or a
    # zero-length sequence) would otherwise raise instead of passing through.
    return x.view(bs, length, heads, width // heads).transpose(1, 2)
class PerceiverAttention(nn.Module):
    """Multi-head attention in which latent queries attend to ``x`` and themselves.

    Queries are projected from the latents; keys and values are projected from
    the concatenation of ``x`` and the latents along the sequence axis.

    Args:
        dim: Feature size of both ``x`` and ``latents``.
        dim_head: Feature size of each attention head.
        heads: Number of attention heads. The internal attention width is
            ``dim_head * heads`` and does not have to equal ``dim``.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim)
        self.to_kv = nn.Linear(dim, inner_dim * 2)
        self.to_out = nn.Linear(inner_dim, dim)

        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_kv.weight)
        nn.init.xavier_uniform_(self.to_out.weight)

    def _split_heads(self, t):
        """Reshape ``(batch, length, heads * dim_head)`` to ``(batch, heads, length, dim_head)``."""
        bs, length, _ = t.shape
        return t.view(bs, length, self.heads, self.dim_head).transpose(1, 2)

    def forward(self, x, latents):
        """Attend from ``latents`` to ``[x; latents]``.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.
            latents: Tensor of shape ``(batch, num_latents, dim)``.

        Returns:
            Tensor of shape ``(batch, num_latents, dim)``. No residual
            connection is added here.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        bs, num_latents, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values see the input sequence followed by the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q, k, v = (self._split_heads(t) for t in (q, k, v))

        attn_score = (q @ k.transpose(-2, -1)) * self.scale
        attn_weight = F.softmax(attn_score, dim=-1)
        out = attn_weight @ v  # (batch, heads, num_latents, dim_head)

        # Merge heads back to the attention width (heads * dim_head), which is
        # what to_out expects. Reshaping to latents.shape instead would only
        # work when heads * dim_head happens to equal dim.
        out = out.transpose(1, 2).reshape(bs, num_latents, self.heads * self.dim_head)
        return self.to_out(out)
class Resampler(nn.Module):
    """Resample a variable-length embedding sequence into a fixed set of tokens.

    A learned set of ``num_queries`` latent vectors is refined by ``depth``
    layers, each a ``PerceiverAttention`` over the projected input followed by
    a ``FeedForward`` block, and then projected to ``output_dim``.

    Args:
        dim: Working feature size of the latents and of the projected input.
        depth: Number of (attention, feed-forward) layer pairs.
        dim_head: Per-head feature size passed to ``PerceiverAttention``.
        heads: Number of attention heads passed to ``PerceiverAttention``.
        num_queries: Number of learned latent vectors (output sequence length).
        embedding_dim: Feature size of the input ``x``.
        output_dim: Feature size of the returned tokens.
        ff_mult: Hidden-width multiplier passed to ``FeedForward``.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.empty(1, num_queries, dim))
        nn.init.normal_(self.latents, mean=0, std=dim**-0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        """Map ``x`` to ``num_queries`` output tokens.

        Args:
            x: Tensor of shape ``(batch, seq_len, embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, num_queries, output_dim)``.
        """
        # expand() broadcasts the single learned latent set over the batch
        # without copying; every later step produces a new tensor, so the
        # shared storage is never written to.
        latents = self.latents.expand(x.size(0), -1, -1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            # FeedForward.forward already returns its input plus the MLP
            # output, so adding `latents` again here would apply the skip
            # connection twice (2 * latents + mlp) and double the latent
            # magnitude at every layer.
            # NOTE(review): this changes numerics relative to the previous
            # double-residual form; re-validate any checkpoint trained with it.
            latents = ff(latents)
        return self.norm_out(self.proj_out(latents))