AbstractPhil committed on
Commit a4ecb4f · verified · 1 Parent(s): 9e54ad1

Create model.py

Files changed (1)
  1. model.py +160 -0
model.py ADDED
@@ -0,0 +1,160 @@
ADAPTER_CONFIG = {
    "adapter_id": "003",
    "name": "DualShuntAdapter-G",

    "t5": {
        "model": "google/flan-t5-base",
        "hidden_size": 768,
    },
    "clip": {
        "model": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        "hidden_size": 1280,
    },

    "bottleneck": 640,
    "heads": 20,

    "tau_init": 0.1,
    "max_guidance": 10.0,

    "proj_layers": 2,
    "layer_norm": True,
    "dropout": 0.1,
    "use_dropout": True,
    "use_proj_stack": True,
    "assert_input_dims": True,

    "routing": {
        "type": "cross_attention",
        "enable_causal_mask": False,
        "bidirectional": True
    },

    "version": "v0.3.2",
    "description": "Final Dual Shunt Adapter with projection stack, dropout, and stacked residual refinement pocket."
}

import torch
import torch.nn as nn
import torch.nn.functional as F

# ─── Residual Pocket Block ───────────────────────────────────
class BottleneckResBlock(nn.Module):
    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = x.transpose(1, 2)
        x = self.conv(x).transpose(1, 2)
        return residual + self.proj(x)

# ─── Two Stream Shunt Adapter ──────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.t5_dim = config["t5"]["hidden_size"]
        self.clip_dim = config["clip"]["hidden_size"]
        self.bneck = config["bottleneck"]
        self.heads = config["heads"]
        self.tau_init = config["tau_init"]
        self.max_guidance = config["max_guidance"]

        use_norm = config.get("layer_norm", True)
        use_do = config.get("use_dropout", True)
        do_p = config.get("dropout", 0.1)
        proj_depth = config.get("proj_layers", 2)

        def build_projection(input_dim, output_dim):
            layers = []
            last_dim = input_dim
            if use_norm:
                layers.append(nn.LayerNorm(last_dim))
            for i in range(proj_depth):
                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
                layers.append(nn.Linear(last_dim, next_dim))
                layers.append(nn.GELU())
                if use_do:
                    layers.append(nn.Dropout(do_p))
                last_dim = next_dim
            layers.append(nn.Linear(last_dim, output_dim))
            return nn.Sequential(*layers)

        # Projections
        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
        self.proj_clip = build_projection(self.clip_dim, self.bneck)

        # Attention
        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))

        # Residual Pocket
        self.pocket_blocks = nn.Sequential(
            BottleneckResBlock(self.bneck, dropout=do_p),
            BottleneckResBlock(self.bneck, dropout=do_p)
        )

        # Fuse
        self.fuse = nn.Sequential(
            nn.LayerNorm(2 * self.bneck),
            nn.Linear(2 * self.bneck, self.bneck * 2),
            nn.GELU(),
            nn.Linear(self.bneck * 2, self.bneck)
        )

        # Output Projections
        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
        self.delta_proj = build_projection(self.bneck, self.clip_dim)
        self.logsig_proj = build_projection(self.bneck, self.clip_dim)

        self.gate_proj = nn.Sequential(
            nn.LayerNorm(self.bneck),
            nn.Linear(self.bneck, self.bneck),
            nn.GELU(),
            nn.Linear(self.bneck, 1),
            nn.Tanh(),
            nn.Sigmoid()
        )

        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(self.bneck),
            nn.Linear(self.bneck, 1),
            nn.Sigmoid()
        )

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        if self.config.get("assert_input_dims", True):
            assert t5_seq.size(-1) == self.t5_dim
            assert clip_seq.size(-1) == self.clip_dim

        t5_b = self.proj_t5(t5_seq)
        clip_b = self.proj_clip(clip_seq)

        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

        pocket = self.pocket_blocks(t2c)

        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))

        anchor = self.anchor_proj(h)
        gate = self.gate_proj(h)
        delta = self.delta_proj(h) * gate
        log_sigma = self.logsig_proj(h)

        g_tok = self.guidance_proj(h).squeeze(-1)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
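
For context, a minimal usage sketch (not part of the commit): it instantiates the adapter from ADAPTER_CONFIG and runs a forward pass on dummy T5 and CLIP hidden-state sequences. The batch size and sequence length of 77 are illustrative assumptions, not values taken from the repository.

# Minimal smoke-test sketch; shapes below are assumptions for illustration.
import torch
from model import ADAPTER_CONFIG, TwoStreamShuntAdapter

adapter = TwoStreamShuntAdapter(ADAPTER_CONFIG).eval()

# Dummy encoder outputs: [batch, seq_len, hidden_size]
t5_seq = torch.randn(2, 77, ADAPTER_CONFIG["t5"]["hidden_size"])      # [2, 77, 768]
clip_seq = torch.randn(2, 77, ADAPTER_CONFIG["clip"]["hidden_size"])  # [2, 77, 1280]

with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

print(anchor.shape, delta.shape, log_sigma.shape)  # each [2, 77, 1280]
print(g_pred.shape, gate.shape)                    # [2, 1] and [2, 77, 1]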