jihao committed
Commit f73bf08 · Parent(s): 9d021bd

update eval files

Files changed:
- eva_vit_model/__init__.py +1 -0
- eva_vit_model/eva_vit.py +575 -0
- eva_vit_model/rope.py +137 -0
- eva_vit_model/transformer.py +625 -0
- eva_vit_model/uta_clip.py +31 -0
- imagenet_zeroshot_data.py +254 -0
- imagenet_zeroshot_eval.py +108 -0
- requirements.txt +6 -0
eva_vit_model/__init__.py
ADDED
@@ -0,0 +1 @@
from .uta_clip import CLIP
eva_vit_model/eva_vit.py
ADDED
@@ -0,0 +1,575 @@
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except:
    from timm.layers import drop_path, to_2tuple, trunc_normal_

from .transformer import PatchDropout, LayerNorm
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast

if os.getenv('ENV_TYPE') == 'deepspeed':
    try:
        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
    except:
        from torch.utils.checkpoint import checkpoint
else:
    from torch.utils.checkpoint import checkpoint

try:
    import xformers.ops as xops
except ImportError:
    xops = None
    print("Please 'pip install xformers'")


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        drop=0.,
        subln=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()

        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()

        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the orignal BERT implement
        x = self.ffn_ln(x)

        x = self.fc2(x)
        x = self.drop(x)
        return x

class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        if self.subln:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        # self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        self.inner_attn_ln = nn.Identity()
        # self.proj = nn.Linear(all_head_dim, all_head_dim)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        B, N, C = x.shape
        if self.subln:
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:

            qkv_bias = None
            if self.q_bias is not None:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # slightly fast impl
            q_t = q[:, :, 1:, :]
            ro_q_t = self.rope(q_t)
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            k_t = k[:, :, 1:, :]
            ro_k_t = self.rope(k_t)
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(
                q, k, v,
                p=self.xattn_drop,
                scale=self.scale,
            )
            x = x.reshape(B, N, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if self.relative_position_bias_table is not None:
                relative_position_bias = \
                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        self.window_size[0] * self.window_size[1] + 1,
                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
                 subln=False, naiveswiglu=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
            xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if naiveswiglu:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                subln=subln,
                drop=drop
            )

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        if self.gamma_1 is None:
            if self.postnorm:
                x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            if self.postnorm:
                x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
                x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class EVAVisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
                 use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
                 pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False, head_2mlp=False):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.head_2mlp = head_2mlp

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        if rope:
            half_head_dim = embed_dim // num_heads // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len if intp_freq else None,
                # patch_dropout=patch_dropout
            )
        else:
            self.rope = None

        self.naiveswiglu = naiveswiglu

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

        if head_2mlp:
            self.proj = nn.Linear(embed_dim, 512)
            self.out_norm = norm_layer(512)
            self.head_clip = nn.Linear(512, num_classes)
            del self.head

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.grad_checkpointing = grad_checkpointing

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            if self.naiveswiglu:
                rescale(layer.mlp.w3.weight.data, layer_id + 1)
            else:
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, return_all_features=False):

        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        if os.getenv('RoPE') == '1':
            if self.training and not isinstance(self.patch_dropout, nn.Identity):
                x, patch_indices_keep = self.patch_dropout(x)
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
            else:
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
                x = self.patch_dropout(x)
        else:
            x = self.patch_dropout(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.grad_checkpointing:
                x = checkpoint(blk, x, (rel_pos_bias,))
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias)

        if not return_all_features:
            x = self.norm(x)
            if self.fc_norm is not None:
                return self.fc_norm(x.mean(1))
            else:
                return x[:, 0]
        return x

    def forward(self, x, return_all_features=False):
        if return_all_features:
            return self.forward_features(x, return_all_features)
        x = self.forward_features(x)
        if self.head_2mlp:
            x = self.proj(x)
            x = self.out_norm(x)
            x = self.head_clip(x)
        else:
            x = self.head(x)
        return x


def eva_base_p16():
    model = EVAVisionTransformer(
        depth=12, embed_dim=768, num_heads=12, mlp_ratio=2.6667, num_classes=1024,
        xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
        subln=True, use_mean_pooling=False, qkv_bias=True,
        norm_layer=partial(LayerNorm, eps=1e-6)
    )
    return model

def eva_large_p14_336():
    model = EVAVisionTransformer(
        img_size=336,
        depth=24, embed_dim=1024, num_heads=16, mlp_ratio=2.6667, patch_size=14, num_classes=1024,
        xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
        subln=True, use_mean_pooling=False, qkv_bias=True,
        norm_layer=partial(LayerNorm, eps=1e-6)
    )
    return model


def eva_giant_p14_336():
    model = EVAVisionTransformer(
        img_size=336,
        depth=40, embed_dim=1408, num_heads=16, mlp_ratio=2.909133333333333, patch_size=14, num_classes=1024,
        xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
        subln=True, use_mean_pooling=False, qkv_bias=True,
        norm_layer=partial(LayerNorm, eps=1e-6)
    )
    return model
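A minimal usage sketch for the vision-tower builders above (not part of the commit); it assumes the package is importable as eva_vit_model.eva_vit and that xformers is installed, since the builders enable xattn=True.

# Hypothetical usage sketch for the EVA vision builders defined above.
import torch
from eva_vit_model.eva_vit import eva_base_p16

model = eva_base_p16().eval()          # ViT-B/16 tower with a 1024-d output head
images = torch.randn(2, 3, 224, 224)   # PatchEmbed asserts a 224x224 input for this builder
with torch.no_grad():
    feats = model(images)              # -> shape (2, 1024) image embeddings
print(feats.shape)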
eva_vit_model/rope.py
ADDED
@@ -0,0 +1,137 @@
from math import pi
import torch
from torch import nn
from einops import rearrange, repeat
import logging

def broadcat(tensors, dim = -1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim = dim)

def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')


class VisionRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, start_index = 0):
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)

        return torch.cat((t_left, t, t_right), dim = -1)

class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        patch_dropout = 0.
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, patch_indices_keep=None):
        if patch_indices_keep is not None:
            batch = t.size()[0]
            batch_indices = torch.arange(batch)
            batch_indices = batch_indices[..., None]

            freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
            freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])

            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
            freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
            freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')

            return t * freqs_cos + rotate_half(t) * freqs_sin

        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
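A small illustrative sketch (not part of the commit) of how the fast rotary embedding above is applied; the shapes mirror how Attention in eva_vit.py calls self.rope on the non-CLS query/key tokens.

# Illustrative sketch: rotate per-head patch tokens with VisionRotaryEmbeddingFast.
# dim=32 is half the head dim (64) for embed_dim=768 / 12 heads; ft_seq_len=14 for a 14x14 grid.
import torch
from eva_vit_model.rope import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=14)
q = torch.randn(2, 12, 14 * 14, 64)   # (batch, heads, patch tokens, head_dim), CLS excluded
q_rot = rope(q)                        # freqs of shape (196, 64) broadcast over batch and heads
print(q_rot.shape)                     # torch.Size([2, 12, 196, 64])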
eva_vit_model/transformer.py
ADDED
@@ -0,0 +1,625 @@
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
import math
|
5 |
+
from typing import Callable, Optional, Sequence
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
|
11 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
12 |
+
try:
|
13 |
+
import deepspeed
|
14 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
15 |
+
except:
|
16 |
+
print("Please 'pip install deepspeed'")
|
17 |
+
deepspeed = None
|
18 |
+
from torch.utils.checkpoint import checkpoint
|
19 |
+
else:
|
20 |
+
from torch.utils.checkpoint import checkpoint
|
21 |
+
|
22 |
+
try:
|
23 |
+
import xformers.ops as xops
|
24 |
+
except ImportError:
|
25 |
+
xops = None
|
26 |
+
print("Please 'pip install xformers'")
|
27 |
+
|
28 |
+
class LayerNormFp32(nn.LayerNorm):
|
29 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
30 |
+
def __init__(self, *args, **kwargs):
|
31 |
+
super().__init__(*args, **kwargs)
|
32 |
+
|
33 |
+
def forward(self, x: torch.Tensor):
|
34 |
+
output = F.layer_norm(
|
35 |
+
x.float(),
|
36 |
+
self.normalized_shape,
|
37 |
+
self.weight.float() if self.weight is not None else None,
|
38 |
+
self.bias.float() if self.bias is not None else None,
|
39 |
+
self.eps,
|
40 |
+
)
|
41 |
+
return output.type_as(x)
|
42 |
+
|
43 |
+
|
44 |
+
class LayerNorm(nn.LayerNorm):
|
45 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
46 |
+
|
47 |
+
def forward(self, x: torch.Tensor):
|
48 |
+
orig_type = x.dtype
|
49 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
50 |
+
return x.to(orig_type)
|
51 |
+
|
52 |
+
class QuickGELU(nn.Module):
|
53 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
54 |
+
def forward(self, x: torch.Tensor):
|
55 |
+
return x * torch.sigmoid(1.702 * x)
|
56 |
+
|
57 |
+
|
58 |
+
class LayerScale(nn.Module):
|
59 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
60 |
+
super().__init__()
|
61 |
+
self.inplace = inplace
|
62 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
66 |
+
|
67 |
+
class PatchDropout(nn.Module):
|
68 |
+
"""
|
69 |
+
https://arxiv.org/abs/2212.00794
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, prob, exclude_first_token=True):
|
73 |
+
super().__init__()
|
74 |
+
assert 0 <= prob < 1.
|
75 |
+
self.prob = prob
|
76 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
77 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
if not self.training or self.prob == 0.:
|
81 |
+
return x
|
82 |
+
|
83 |
+
if self.exclude_first_token:
|
84 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
85 |
+
else:
|
86 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
87 |
+
|
88 |
+
batch = x.size()[0]
|
89 |
+
num_tokens = x.size()[1]
|
90 |
+
|
91 |
+
batch_indices = torch.arange(batch)
|
92 |
+
batch_indices = batch_indices[..., None]
|
93 |
+
|
94 |
+
keep_prob = 1 - self.prob
|
95 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
96 |
+
|
97 |
+
rand = torch.randn(batch, num_tokens)
|
98 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
99 |
+
|
100 |
+
x = x[batch_indices, patch_indices_keep]
|
101 |
+
|
102 |
+
if self.exclude_first_token:
|
103 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
104 |
+
|
105 |
+
if self.training and os.getenv('RoPE') == '1':
|
106 |
+
return x, patch_indices_keep
|
107 |
+
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
def _in_projection_packed(
|
112 |
+
q: torch.Tensor,
|
113 |
+
k: torch.Tensor,
|
114 |
+
v: torch.Tensor,
|
115 |
+
w: torch.Tensor,
|
116 |
+
b: Optional[torch.Tensor] = None,
|
117 |
+
):
|
118 |
+
"""
|
119 |
+
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
|
120 |
+
"""
|
121 |
+
E = q.size(-1)
|
122 |
+
if k is v:
|
123 |
+
if q is k:
|
124 |
+
# self-attention
|
125 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
126 |
+
else:
|
127 |
+
# encoder-decoder attention
|
128 |
+
w_q, w_kv = w.split([E, E * 2])
|
129 |
+
if b is None:
|
130 |
+
b_q = b_kv = None
|
131 |
+
else:
|
132 |
+
b_q, b_kv = b.split([E, E * 2])
|
133 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
134 |
+
else:
|
135 |
+
w_q, w_k, w_v = w.chunk(3)
|
136 |
+
if b is None:
|
137 |
+
b_q = b_k = b_v = None
|
138 |
+
else:
|
139 |
+
b_q, b_k, b_v = b.chunk(3)
|
140 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
141 |
+
|
142 |
+
class Attention(nn.Module):
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
dim,
|
146 |
+
num_heads=8,
|
147 |
+
qkv_bias=True,
|
148 |
+
scaled_cosine=False,
|
149 |
+
scale_heads=False,
|
150 |
+
logit_scale_max=math.log(1. / 0.01),
|
151 |
+
attn_drop=0.,
|
152 |
+
proj_drop=0.,
|
153 |
+
xattn=False,
|
154 |
+
rope=False
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.scaled_cosine = scaled_cosine
|
158 |
+
self.scale_heads = scale_heads
|
159 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
160 |
+
self.num_heads = num_heads
|
161 |
+
self.head_dim = dim // num_heads
|
162 |
+
self.scale = self.head_dim ** -0.5
|
163 |
+
self.logit_scale_max = logit_scale_max
|
164 |
+
|
165 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
166 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
167 |
+
if qkv_bias:
|
168 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
169 |
+
else:
|
170 |
+
self.in_proj_bias = None
|
171 |
+
|
172 |
+
if self.scaled_cosine:
|
173 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
174 |
+
else:
|
175 |
+
self.logit_scale = None
|
176 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
177 |
+
if self.scale_heads:
|
178 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
179 |
+
else:
|
180 |
+
self.head_scale = None
|
181 |
+
self.out_proj = nn.Linear(dim, dim)
|
182 |
+
self.out_drop = nn.Dropout(proj_drop)
|
183 |
+
self.xattn = xattn
|
184 |
+
self.xattn_drop = attn_drop
|
185 |
+
self.rope = rope
|
186 |
+
|
187 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
188 |
+
L, N, C = x.shape
|
189 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
190 |
+
if self.xattn:
|
191 |
+
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
192 |
+
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
193 |
+
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
194 |
+
|
195 |
+
x = xops.memory_efficient_attention(
|
196 |
+
q, k, v,
|
197 |
+
p=self.xattn_drop,
|
198 |
+
scale=self.scale if self.logit_scale is None else None,
|
199 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
203 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
204 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
205 |
+
|
206 |
+
if self.logit_scale is not None:
|
207 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
208 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
209 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
210 |
+
attn = attn.view(-1, L, L)
|
211 |
+
else:
|
212 |
+
q = q * self.scale
|
213 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
214 |
+
|
215 |
+
if attn_mask is not None:
|
216 |
+
if attn_mask.dtype == torch.bool:
|
217 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
218 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
219 |
+
attn_mask = new_attn_mask
|
220 |
+
attn += attn_mask
|
221 |
+
|
222 |
+
attn = attn.softmax(dim=-1)
|
223 |
+
attn = self.attn_drop(attn)
|
224 |
+
|
225 |
+
x = torch.bmm(attn, v)
|
226 |
+
|
227 |
+
if self.head_scale is not None:
|
228 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
229 |
+
x = x.view(-1, L, C)
|
230 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
231 |
+
x = self.out_proj(x)
|
232 |
+
x = self.out_drop(x)
|
233 |
+
return x
|
234 |
+
|
235 |
+
class CustomAttention(nn.Module):
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
dim,
|
239 |
+
num_heads=8,
|
240 |
+
qkv_bias=True,
|
241 |
+
scaled_cosine=True,
|
242 |
+
scale_heads=False,
|
243 |
+
logit_scale_max=math.log(1. / 0.01),
|
244 |
+
attn_drop=0.,
|
245 |
+
proj_drop=0.,
|
246 |
+
xattn=False
|
247 |
+
):
|
248 |
+
super().__init__()
|
249 |
+
self.scaled_cosine = scaled_cosine
|
250 |
+
self.scale_heads = scale_heads
|
251 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
252 |
+
self.num_heads = num_heads
|
253 |
+
self.head_dim = dim // num_heads
|
254 |
+
self.scale = self.head_dim ** -0.5
|
255 |
+
self.logit_scale_max = logit_scale_max
|
256 |
+
|
257 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
258 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
259 |
+
if qkv_bias:
|
260 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
261 |
+
else:
|
262 |
+
self.in_proj_bias = None
|
263 |
+
|
264 |
+
if self.scaled_cosine:
|
265 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
266 |
+
else:
|
267 |
+
self.logit_scale = None
|
268 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
269 |
+
if self.scale_heads:
|
270 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
271 |
+
else:
|
272 |
+
self.head_scale = None
|
273 |
+
self.out_proj = nn.Linear(dim, dim)
|
274 |
+
self.out_drop = nn.Dropout(proj_drop)
|
275 |
+
self.xattn = xattn
|
276 |
+
self.xattn_drop = attn_drop
|
277 |
+
|
278 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
279 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
280 |
+
N_q, B_q, C_q = q.shape
|
281 |
+
N_k, B_k, C_k = k.shape
|
282 |
+
N_v, B_v, C_v = v.shape
|
283 |
+
if self.xattn:
|
284 |
+
# B, N, C -> B, N, num_heads, C
|
285 |
+
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
|
286 |
+
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
|
287 |
+
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
|
288 |
+
|
289 |
+
x = xops.memory_efficient_attention(
|
290 |
+
q, k, v,
|
291 |
+
p=self.xattn_drop,
|
292 |
+
scale=self.scale if self.logit_scale is None else None,
|
293 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
# B*H, L, C
|
297 |
+
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
|
298 |
+
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
|
299 |
+
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
|
300 |
+
|
301 |
+
if self.logit_scale is not None:
|
302 |
+
# B*H, N_q, N_k
|
303 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
304 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
305 |
+
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
|
306 |
+
attn = attn.view(-1, N_q, N_k)
|
307 |
+
else:
|
308 |
+
q = q * self.scale
|
309 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
310 |
+
|
311 |
+
if attn_mask is not None:
|
312 |
+
if attn_mask.dtype == torch.bool:
|
313 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
314 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
315 |
+
attn_mask = new_attn_mask
|
316 |
+
attn += attn_mask
|
317 |
+
|
318 |
+
attn = attn.softmax(dim=-1)
|
319 |
+
attn = self.attn_drop(attn)
|
320 |
+
|
321 |
+
x = torch.bmm(attn, v)
|
322 |
+
|
323 |
+
if self.head_scale is not None:
|
324 |
+
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
|
325 |
+
x = x.view(-1, N_q, C_q)
|
326 |
+
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
|
327 |
+
x = self.out_proj(x)
|
328 |
+
x = self.out_drop(x)
|
329 |
+
return x
|
330 |
+
|
331 |
+
class CustomResidualAttentionBlock(nn.Module):
|
332 |
+
def __init__(
|
333 |
+
self,
|
334 |
+
d_model: int,
|
335 |
+
n_head: int,
|
336 |
+
mlp_ratio: float = 4.0,
|
337 |
+
ls_init_value: float = None,
|
338 |
+
act_layer: Callable = nn.GELU,
|
339 |
+
norm_layer: Callable = LayerNorm,
|
340 |
+
scale_cosine_attn: bool = False,
|
341 |
+
scale_heads: bool = False,
|
342 |
+
scale_attn: bool = False,
|
343 |
+
scale_fc: bool = False,
|
344 |
+
cross_attn: bool = False,
|
345 |
+
xattn: bool = False,
|
346 |
+
):
|
347 |
+
super().__init__()
|
348 |
+
|
349 |
+
self.ln_1 = norm_layer(d_model)
|
350 |
+
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
|
351 |
+
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
|
352 |
+
self.attn = CustomAttention(
|
353 |
+
d_model, n_head,
|
354 |
+
qkv_bias=True,
|
355 |
+
attn_drop=0.,
|
356 |
+
proj_drop=0.,
|
357 |
+
scaled_cosine=scale_cosine_attn,
|
358 |
+
scale_heads=scale_heads,
|
359 |
+
xattn=xattn
|
360 |
+
)
|
361 |
+
|
362 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
363 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
364 |
+
|
365 |
+
self.ln_2 = norm_layer(d_model)
|
366 |
+
mlp_width = int(d_model * mlp_ratio)
|
367 |
+
self.mlp = nn.Sequential(OrderedDict([
|
368 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
369 |
+
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
370 |
+
("gelu", act_layer()),
|
371 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
372 |
+
]))
|
373 |
+
|
374 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
375 |
+
|
376 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
377 |
+
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
|
378 |
+
q = q + self.ls_2(self.mlp(self.ln_2(q)))
|
379 |
+
return q
|
380 |
+
|
381 |
+
class CustomTransformer(nn.Module):
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
width: int,
|
385 |
+
layers: int,
|
386 |
+
heads: int,
|
387 |
+
mlp_ratio: float = 4.0,
|
388 |
+
ls_init_value: float = None,
|
389 |
+
act_layer: Callable = nn.GELU,
|
390 |
+
norm_layer: Callable = LayerNorm,
|
391 |
+
scale_cosine_attn: bool = True,
|
392 |
+
scale_heads: bool = False,
|
393 |
+
scale_attn: bool = False,
|
394 |
+
scale_fc: bool = False,
|
395 |
+
cross_attn: bool = False,
|
396 |
+
xattn: bool = False,
|
397 |
+
):
|
398 |
+
super().__init__()
|
399 |
+
self.width = width
|
400 |
+
self.layers = layers
|
401 |
+
self.grad_checkpointing = False
|
402 |
+
self.xattn = xattn
|
403 |
+
|
404 |
+
self.resblocks = nn.ModuleList([
|
405 |
+
CustomResidualAttentionBlock(
|
406 |
+
width,
|
407 |
+
heads,
|
408 |
+
mlp_ratio,
|
409 |
+
ls_init_value=ls_init_value,
|
410 |
+
act_layer=act_layer,
|
411 |
+
norm_layer=norm_layer,
|
412 |
+
scale_cosine_attn=scale_cosine_attn,
|
413 |
+
scale_heads=scale_heads,
|
414 |
+
scale_attn=scale_attn,
|
415 |
+
scale_fc=scale_fc,
|
416 |
+
cross_attn=cross_attn,
|
417 |
+
xattn=xattn)
|
418 |
+
for _ in range(layers)
|
419 |
+
])
|
420 |
+
|
421 |
+
def get_cast_dtype(self) -> torch.dtype:
|
422 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
423 |
+
|
424 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
|
425 |
+
if k is None and v is None:
|
426 |
+
k = v = q
|
427 |
+
for r in self.resblocks:
|
428 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
429 |
+
q = checkpoint(r, q, k, v, attn_mask)
|
430 |
+
else:
|
431 |
+
q = r(q, k, v, attn_mask=attn_mask)
|
432 |
+
return q
|
433 |
+
|
434 |
+
|
435 |
+
class ResidualAttentionBlock(nn.Module):
|
436 |
+
def __init__(
|
437 |
+
self,
|
438 |
+
d_model: int,
|
439 |
+
n_head: int,
|
440 |
+
mlp_ratio: float = 4.0,
|
441 |
+
ls_init_value: float = None,
|
442 |
+
act_layer: Callable = nn.GELU,
|
443 |
+
norm_layer: Callable = LayerNorm,
|
444 |
+
xattn: bool = False,
|
445 |
+
):
|
446 |
+
super().__init__()
|
447 |
+
|
448 |
+
self.ln_1 = norm_layer(d_model)
|
449 |
+
if xattn:
|
450 |
+
            self.attn = Attention(d_model, n_head, xattn=True)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))

        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.xattn = xattn

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        if self.xattn:
            return self.attn(x, attn_mask=attn_mask)
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x


class TextTransformer(nn.Module):
    def __init__(
            self,
            context_length: int = 77,
            vocab_size: int = 49408,
            width: int = 512,
            heads: int = 8,
            layers: int = 12,
            ls_init_value: float = None,
            output_dim: int = 512,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool = False,
            attn_mask: bool = True
    ):
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            xattn=xattn
        )

        self.xattn = xattn
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        if attn_mask:
            self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
        else:
            self.attn_mask = None

        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
        else:
            raise ValueError("Not support partial freeze")

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # return {'positional_embedding', 'token_embedding'}
        return {'positional_embedding'}

    def get_num_layers(self):
        return self.transformer.layers

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text, return_all_features: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        # x = self.transformer(x)  # no attention mask is applied
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        if not return_all_features:
            # x.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x


def text_transformer():
    model = TextTransformer(
        width=1024,
        output_dim=1024,
        heads=16,
        layers=24,
        xattn=True
    )
    return model
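For readers skimming the diff: `build_attention_mask` produces the additive causal mask fed to every text block, and `forward` pools the token at `text.argmax(dim=-1)` (the file's own comment notes the EOT token has the highest id in each sequence). A standalone sketch of what the mask looks like for a toy `context_length`; this snippet is illustrative only and not part of the committed file:

import torch

context_length = 4
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # -inf strictly above the diagonal, zeros on and below it
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])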
eva_vit_model/uta_clip.py
ADDED
@@ -0,0 +1,31 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import eva_vit
from .transformer import text_transformer

class CLIP(nn.Module):
    def __init__(
            self,
            vision_model: str = 'eva_base_p16',
    ):
        super().__init__()
        self.visual = eva_vit.__dict__[vision_model]()
        self.text = text_transformer()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()
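`forward` returns L2-normalized image and text features plus the exponentiated logit scale. A minimal usage sketch (hypothetical: it assumes a constructed `model`, a preprocessed image batch `images` of shape [B, 3, H, W] and tokenized `texts` of shape [B, 77]) showing how those three outputs are typically combined into contrastive logits:

import torch

# Sketch only, not part of the commit: `model`, `images`, `texts` are assumed to exist.
with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)
    logits_per_image = logit_scale * image_features @ text_features.t()  # [B, B] similarity matrix
    logits_per_text = logits_per_image.t()
    probs = logits_per_image.softmax(dim=-1)  # each image scored against every text in the batch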
imagenet_zeroshot_data.py
ADDED
@@ -0,0 +1,254 @@


imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
    "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
    "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
    "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
    "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
    "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
    "box turtle", "banded gecko", "green iguana", "Carolina anole",
    "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
    "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
    "American alligator", "triceratops", "worm snake", "ring-necked snake",
    "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
    "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
    "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
    "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
    "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
    "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
    "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
    "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
    "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
    "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
    "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
    "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
    "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
    "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
    "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
    "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
    "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
    "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
    "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
    "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
    "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
    "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
    "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
    "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
    "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
    "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
    "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
    "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
    "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
    "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
    "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
    "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
    "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
    "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
    "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
    "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
    "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
    "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
    "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
    "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
    "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
    "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
    "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
    "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
    "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
    "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
    "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
    "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
    "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
    "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
    "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
    "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
    "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
    "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
    "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
    "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
    "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
    "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
    "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
    "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
    "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
    "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
    "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
    "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
    "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
    "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
    "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
    "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
    "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
    "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
    "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
    "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
    "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
    "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
    "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
    "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
    "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
    "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
    "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
    "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
    "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
    "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
    "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
    "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
    "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
    "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
    "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
    "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
    "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
    "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
    "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
    "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
    "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
    "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
    "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
    "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
    "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
    "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
    "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
    "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
    "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
    "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
    "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
    "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
    "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
    "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
    "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
    "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
    "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
    "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
    "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
    "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
    "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
    "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
    "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
    "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
    "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
    "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
    "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
    "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
    "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
    "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
    "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
    "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
    "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
    "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
    "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
    "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
    "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
    "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
    "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
    "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
    "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
    "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
    "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
    "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
    "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
    "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
    "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
    "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
    "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
    "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
    "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
    "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
    "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
    "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
    "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
    "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
    "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
    "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
    "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
    "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
    "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
    "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
    "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]


openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
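The two lists above are consumed by the zero-shot evaluator below. A quick standalone sanity check (illustrative only, run from the repo root) of their expected sizes and of how a template is applied:

from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template

assert len(imagenet_classnames) == 1000      # one label per ImageNet-1k class
assert len(openai_imagenet_template) == 80   # the OpenAI prompt templates
print(openai_imagenet_template[0]("goldfish"))  # -> 'a bad photo of a goldfish.'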
imagenet_zeroshot_eval.py
ADDED
@@ -0,0 +1,108 @@
import os
from tqdm import tqdm
import argparse

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F

from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

import eva_vit_model
from eva_vit_model import CLIP
from open_clip.tokenizer import tokenize
from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template


def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.allow_tf32 = True

    print(f"creating model: {args.model}")
    model = CLIP(vision_model=args.model)

    print(f"loading checkpoint from {args.ckpt_path}")
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    model.to(device)

    def _convert_image_to_rgb(image):
        return image.convert("RGB")

    val_transform = transforms.Compose([
        transforms.Resize(args.image_size, transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.image_size),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
    ])

    val_dataset = datasets.ImageFolder(os.path.join(args.imagenet_path, 'val'), transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.workers)

    model.eval()
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, device)
    top1, top5 = zero_shot_eval(model, classifier, val_loader, device)
    print(f'ImageNet zeroshot top1: {top1:.4f}, top5: {top5:.4f}')


def zero_shot_classifier(model, classnames, templates, device):
    tokenizer = tokenize

    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenizer(texts).to(device=device)  # tokenize
            with torch.cuda.amp.autocast():
                class_embeddings = model.encode_text(texts)
                class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights

def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

def zero_shot_eval(model, classifier, dataloader, device):
    top1, top5, n = 0., 0., 0.
    with torch.no_grad():
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(device=device)
            target = target.to(device=device)

            with torch.cuda.amp.autocast():
                image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = (top1 / n)
    top5 = (top5 / n)
    return top1, top5


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ImageNet zero shot evaluations', add_help=False)
    parser.add_argument('--imagenet-path', default='path/to/imagenet', type=str, help='path to imagenet dataset')
    parser.add_argument('--ckpt-path', default='path/to/ckpt', type=str, help='path to checkpoint')
    parser.add_argument('--batch-size', default=64, type=int, help='batch size')
    parser.add_argument('--model', default='eva_base_p16', type=str, help='model')
    parser.add_argument('--image-size', default=224, type=int, help='image size for evaluation')
    parser.add_argument('--workers', default=8, type=int)
    args = parser.parse_args()
    main(args)
requirements.txt
ADDED
@@ -0,0 +1,6 @@
tqdm
timm
torch
open_clip_torch  # provides the open_clip module imported by imagenet_zeroshot_eval.py
torchvision
xformers
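With the dependencies above installed and a local ImageNet folder containing a `val/` ImageFolder split, the evaluation script can be launched along these lines (paths are placeholders, matching the script's own defaults in spirit):

python imagenet_zeroshot_eval.py \
    --model eva_base_p16 \
    --ckpt-path /path/to/checkpoint.pt \
    --imagenet-path /path/to/imagenet \
    --image-size 224 \
    --batch-size 64 --workers 8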