Zero-Shot Classification
jihao commited on
Commit
f73bf08
·
1 Parent(s): 9d021bd

update eval files

Browse files
eva_vit_model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .uta_clip import CLIP
eva_vit_model/eva_vit.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ try:
11
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
+ except:
13
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
14
+
15
+ from .transformer import PatchDropout, LayerNorm
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+
18
+ if os.getenv('ENV_TYPE') == 'deepspeed':
19
+ try:
20
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
+ except:
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers.ops as xops
28
+ except ImportError:
29
+ xops = None
30
+ print("Please 'pip install xformers'")
31
+
32
+
33
+ class DropPath(nn.Module):
34
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35
+ """
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training)
42
+
43
+ def extra_repr(self) -> str:
44
+ return 'p={}'.format(self.drop_prob)
45
+
46
+
47
+ class Mlp(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features,
51
+ hidden_features=None,
52
+ out_features=None,
53
+ act_layer=nn.GELU,
54
+ norm_layer=nn.LayerNorm,
55
+ drop=0.,
56
+ subln=False,
57
+
58
+ ):
59
+ super().__init__()
60
+ out_features = out_features or in_features
61
+ hidden_features = hidden_features or in_features
62
+ self.fc1 = nn.Linear(in_features, hidden_features)
63
+ self.act = act_layer()
64
+
65
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
66
+
67
+ self.fc2 = nn.Linear(hidden_features, out_features)
68
+ self.drop = nn.Dropout(drop)
69
+
70
+ def forward(self, x):
71
+ x = self.fc1(x)
72
+ x = self.act(x)
73
+ # x = self.drop(x)
74
+ # commit this for the orignal BERT implement
75
+ x = self.ffn_ln(x)
76
+
77
+ x = self.fc2(x)
78
+ x = self.drop(x)
79
+ return x
80
+
81
+ class SwiGLU(nn.Module):
82
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
83
+ norm_layer=nn.LayerNorm, subln=False):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+
88
+ self.w1 = nn.Linear(in_features, hidden_features)
89
+ self.w2 = nn.Linear(in_features, hidden_features)
90
+
91
+ self.act = act_layer()
92
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
93
+ self.w3 = nn.Linear(hidden_features, out_features)
94
+
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x):
98
+ x1 = self.w1(x)
99
+ x2 = self.w2(x)
100
+ hidden = self.act(x1) * x2
101
+ x = self.ffn_ln(hidden)
102
+ x = self.w3(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+ class Attention(nn.Module):
107
+ def __init__(
108
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
109
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
110
+ super().__init__()
111
+ self.num_heads = num_heads
112
+ head_dim = dim // num_heads
113
+ if attn_head_dim is not None:
114
+ head_dim = attn_head_dim
115
+ all_head_dim = head_dim * self.num_heads
116
+ self.scale = qk_scale or head_dim ** -0.5
117
+
118
+ self.subln = subln
119
+ if self.subln:
120
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
121
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
122
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
123
+ else:
124
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
125
+
126
+ if qkv_bias:
127
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
128
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
129
+ else:
130
+ self.q_bias = None
131
+ self.v_bias = None
132
+
133
+ if window_size:
134
+ self.window_size = window_size
135
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
136
+ self.relative_position_bias_table = nn.Parameter(
137
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
138
+ # cls to token & token 2 cls & cls to cls
139
+
140
+ # get pair-wise relative position index for each token inside the window
141
+ coords_h = torch.arange(window_size[0])
142
+ coords_w = torch.arange(window_size[1])
143
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
144
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
145
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
146
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
147
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
148
+ relative_coords[:, :, 1] += window_size[1] - 1
149
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
150
+ relative_position_index = \
151
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
152
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
153
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
154
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
155
+ relative_position_index[0, 0] = self.num_relative_distance - 1
156
+
157
+ self.register_buffer("relative_position_index", relative_position_index)
158
+ else:
159
+ self.window_size = None
160
+ self.relative_position_bias_table = None
161
+ self.relative_position_index = None
162
+
163
+ self.attn_drop = nn.Dropout(attn_drop)
164
+ # self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
165
+ self.inner_attn_ln = nn.Identity()
166
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
167
+ self.proj = nn.Linear(all_head_dim, dim)
168
+ self.proj_drop = nn.Dropout(proj_drop)
169
+ self.xattn = xattn
170
+ self.xattn_drop = attn_drop
171
+
172
+ self.rope = rope
173
+
174
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
175
+ B, N, C = x.shape
176
+ if self.subln:
177
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
178
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
179
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
180
+
181
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
182
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
183
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
184
+ else:
185
+
186
+ qkv_bias = None
187
+ if self.q_bias is not None:
188
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
189
+
190
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
191
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
192
+ q, k, v = qkv[0], qkv[1], qkv[2]
193
+
194
+ if self.rope:
195
+ # slightly fast impl
196
+ q_t = q[:, :, 1:, :]
197
+ ro_q_t = self.rope(q_t)
198
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
199
+
200
+ k_t = k[:, :, 1:, :]
201
+ ro_k_t = self.rope(k_t)
202
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
203
+
204
+ if self.xattn:
205
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
206
+ k = k.permute(0, 2, 1, 3)
207
+ v = v.permute(0, 2, 1, 3)
208
+
209
+ x = xops.memory_efficient_attention(
210
+ q, k, v,
211
+ p=self.xattn_drop,
212
+ scale=self.scale,
213
+ )
214
+ x = x.reshape(B, N, -1)
215
+ x = self.inner_attn_ln(x)
216
+ x = self.proj(x)
217
+ x = self.proj_drop(x)
218
+ else:
219
+ q = q * self.scale
220
+ attn = (q @ k.transpose(-2, -1))
221
+
222
+ if self.relative_position_bias_table is not None:
223
+ relative_position_bias = \
224
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
225
+ self.window_size[0] * self.window_size[1] + 1,
226
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
227
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
228
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
229
+
230
+ if rel_pos_bias is not None:
231
+ attn = attn + rel_pos_bias.type_as(attn)
232
+
233
+ if attn_mask is not None:
234
+ attn_mask = attn_mask.bool()
235
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
236
+
237
+ attn = attn.softmax(dim=-1)
238
+ attn = self.attn_drop(attn)
239
+
240
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
241
+ x = self.inner_attn_ln(x)
242
+ x = self.proj(x)
243
+ x = self.proj_drop(x)
244
+ return x
245
+
246
+
247
+ class Block(nn.Module):
248
+
249
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
250
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
251
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
252
+ subln=False, naiveswiglu=False):
253
+ super().__init__()
254
+ self.norm1 = norm_layer(dim)
255
+ self.attn = Attention(
256
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
257
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
258
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
259
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
260
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
261
+ self.norm2 = norm_layer(dim)
262
+ mlp_hidden_dim = int(dim * mlp_ratio)
263
+
264
+ if naiveswiglu:
265
+ self.mlp = SwiGLU(
266
+ in_features=dim,
267
+ hidden_features=mlp_hidden_dim,
268
+ subln=subln,
269
+ norm_layer=norm_layer,
270
+ )
271
+ else:
272
+ self.mlp = Mlp(
273
+ in_features=dim,
274
+ hidden_features=mlp_hidden_dim,
275
+ act_layer=act_layer,
276
+ subln=subln,
277
+ drop=drop
278
+ )
279
+
280
+ if init_values is not None and init_values > 0:
281
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
282
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
283
+ else:
284
+ self.gamma_1, self.gamma_2 = None, None
285
+
286
+ self.postnorm = postnorm
287
+
288
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
289
+ if self.gamma_1 is None:
290
+ if self.postnorm:
291
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
292
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
293
+ else:
294
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
295
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
296
+ else:
297
+ if self.postnorm:
298
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
299
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
300
+ else:
301
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
302
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
303
+ return x
304
+
305
+
306
+ class PatchEmbed(nn.Module):
307
+ """ Image to Patch Embedding
308
+ """
309
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
310
+ super().__init__()
311
+ img_size = to_2tuple(img_size)
312
+ patch_size = to_2tuple(patch_size)
313
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
314
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
315
+ self.img_size = img_size
316
+ self.patch_size = patch_size
317
+ self.num_patches = num_patches
318
+
319
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
320
+
321
+ def forward(self, x, **kwargs):
322
+ B, C, H, W = x.shape
323
+ # FIXME look at relaxing size constraints
324
+ assert H == self.img_size[0] and W == self.img_size[1], \
325
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
326
+ x = self.proj(x).flatten(2).transpose(1, 2)
327
+ return x
328
+
329
+
330
+ class RelativePositionBias(nn.Module):
331
+
332
+ def __init__(self, window_size, num_heads):
333
+ super().__init__()
334
+ self.window_size = window_size
335
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
336
+ self.relative_position_bias_table = nn.Parameter(
337
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
338
+ # cls to token & token 2 cls & cls to cls
339
+
340
+ # get pair-wise relative position index for each token inside the window
341
+ coords_h = torch.arange(window_size[0])
342
+ coords_w = torch.arange(window_size[1])
343
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
344
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
345
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
346
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
347
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
348
+ relative_coords[:, :, 1] += window_size[1] - 1
349
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
350
+ relative_position_index = \
351
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
352
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
353
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
354
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
355
+ relative_position_index[0, 0] = self.num_relative_distance - 1
356
+
357
+ self.register_buffer("relative_position_index", relative_position_index)
358
+
359
+ def forward(self):
360
+ relative_position_bias = \
361
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
362
+ self.window_size[0] * self.window_size[1] + 1,
363
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
364
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
365
+
366
+
367
+ class EVAVisionTransformer(nn.Module):
368
+ """ Vision Transformer with support for patch or hybrid CNN input stage
369
+ """
370
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
371
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
372
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
373
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
374
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
375
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False, head_2mlp=False):
376
+ super().__init__()
377
+ self.image_size = img_size
378
+ self.num_classes = num_classes
379
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
380
+ self.head_2mlp = head_2mlp
381
+
382
+ self.patch_embed = PatchEmbed(
383
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
384
+ num_patches = self.patch_embed.num_patches
385
+
386
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
387
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
388
+ if use_abs_pos_emb:
389
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
390
+ else:
391
+ self.pos_embed = None
392
+ self.pos_drop = nn.Dropout(p=drop_rate)
393
+
394
+ if use_shared_rel_pos_bias:
395
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
396
+ else:
397
+ self.rel_pos_bias = None
398
+
399
+ if rope:
400
+ half_head_dim = embed_dim // num_heads // 2
401
+ hw_seq_len = img_size // patch_size
402
+ self.rope = VisionRotaryEmbeddingFast(
403
+ dim=half_head_dim,
404
+ pt_seq_len=pt_hw_seq_len,
405
+ ft_seq_len=hw_seq_len if intp_freq else None,
406
+ # patch_dropout=patch_dropout
407
+ )
408
+ else:
409
+ self.rope = None
410
+
411
+ self.naiveswiglu = naiveswiglu
412
+
413
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
414
+ self.use_rel_pos_bias = use_rel_pos_bias
415
+ self.blocks = nn.ModuleList([
416
+ Block(
417
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
418
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
419
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
420
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
421
+ for i in range(depth)])
422
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
423
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
424
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
425
+
426
+ if self.pos_embed is not None:
427
+ trunc_normal_(self.pos_embed, std=.02)
428
+
429
+ trunc_normal_(self.cls_token, std=.02)
430
+ # trunc_normal_(self.mask_token, std=.02)
431
+
432
+ self.apply(self._init_weights)
433
+ self.fix_init_weight()
434
+
435
+ if isinstance(self.head, nn.Linear):
436
+ trunc_normal_(self.head.weight, std=.02)
437
+ self.head.weight.data.mul_(init_scale)
438
+ self.head.bias.data.mul_(init_scale)
439
+
440
+ if head_2mlp:
441
+ self.proj = nn.Linear(embed_dim, 512)
442
+ self.out_norm = norm_layer(512)
443
+ self.head_clip = nn.Linear(512, num_classes)
444
+ del self.head
445
+
446
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
447
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
448
+
449
+ self.grad_checkpointing = grad_checkpointing
450
+
451
+ def fix_init_weight(self):
452
+ def rescale(param, layer_id):
453
+ param.div_(math.sqrt(2.0 * layer_id))
454
+
455
+ for layer_id, layer in enumerate(self.blocks):
456
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
457
+ if self.naiveswiglu:
458
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
459
+ else:
460
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
461
+
462
+ def get_cast_dtype(self) -> torch.dtype:
463
+ return self.blocks[0].mlp.fc2.weight.dtype
464
+
465
+ def _init_weights(self, m):
466
+ if isinstance(m, nn.Linear):
467
+ trunc_normal_(m.weight, std=.02)
468
+ if m.bias is not None:
469
+ nn.init.constant_(m.bias, 0)
470
+ elif isinstance(m, nn.LayerNorm):
471
+ nn.init.constant_(m.bias, 0)
472
+ nn.init.constant_(m.weight, 1.0)
473
+
474
+ def get_num_layers(self):
475
+ return len(self.blocks)
476
+
477
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
478
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
479
+ for param in self.parameters():
480
+ param.requires_grad = False
481
+
482
+ @torch.jit.ignore
483
+ def set_grad_checkpointing(self, enable=True):
484
+ self.grad_checkpointing = enable
485
+
486
+ @torch.jit.ignore
487
+ def no_weight_decay(self):
488
+ return {'pos_embed', 'cls_token'}
489
+
490
+ def get_classifier(self):
491
+ return self.head
492
+
493
+ def reset_classifier(self, num_classes, global_pool=''):
494
+ self.num_classes = num_classes
495
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
496
+
497
+ def forward_features(self, x, return_all_features=False):
498
+
499
+ x = self.patch_embed(x)
500
+ batch_size, seq_len, _ = x.size()
501
+
502
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
503
+ x = torch.cat((cls_tokens, x), dim=1)
504
+ if self.pos_embed is not None:
505
+ x = x + self.pos_embed
506
+ x = self.pos_drop(x)
507
+
508
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
509
+ if os.getenv('RoPE') == '1':
510
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
511
+ x, patch_indices_keep = self.patch_dropout(x)
512
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
513
+ else:
514
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
515
+ x = self.patch_dropout(x)
516
+ else:
517
+ x = self.patch_dropout(x)
518
+
519
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
520
+ for blk in self.blocks:
521
+ if self.grad_checkpointing:
522
+ x = checkpoint(blk, x, (rel_pos_bias,))
523
+ else:
524
+ x = blk(x, rel_pos_bias=rel_pos_bias)
525
+
526
+ if not return_all_features:
527
+ x = self.norm(x)
528
+ if self.fc_norm is not None:
529
+ return self.fc_norm(x.mean(1))
530
+ else:
531
+ return x[:, 0]
532
+ return x
533
+
534
+ def forward(self, x, return_all_features=False):
535
+ if return_all_features:
536
+ return self.forward_features(x, return_all_features)
537
+ x = self.forward_features(x)
538
+ if self.head_2mlp:
539
+ x = self.proj(x)
540
+ x = self.out_norm(x)
541
+ x = self.head_clip(x)
542
+ else:
543
+ x = self.head(x)
544
+ return x
545
+
546
+
547
+ def eva_base_p16():
548
+ model = EVAVisionTransformer(
549
+ depth=12, embed_dim=768, num_heads=12, mlp_ratio=2.6667, num_classes=1024,
550
+ xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
551
+ subln=True, use_mean_pooling=False, qkv_bias=True,
552
+ norm_layer=partial(LayerNorm, eps=1e-6)
553
+ )
554
+ return model
555
+
556
+ def eva_large_p14_336():
557
+ model = EVAVisionTransformer(
558
+ img_size=336,
559
+ depth=24, embed_dim=1024, num_heads=16, mlp_ratio=2.6667,patch_size=14, num_classes=1024,
560
+ xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
561
+ subln=True, use_mean_pooling=False, qkv_bias=True,
562
+ norm_layer=partial(LayerNorm, eps=1e-6)
563
+ )
564
+ return model
565
+
566
+
567
+ def eva_giant_p14_336():
568
+ model = EVAVisionTransformer(
569
+ img_size=336,
570
+ depth=40, embed_dim=1408, num_heads=16, mlp_ratio=2.909133333333333,patch_size=14, num_classes=1024,
571
+ xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
572
+ subln=True, use_mean_pooling=False, qkv_bias=True,
573
+ norm_layer=partial(LayerNorm, eps=1e-6)
574
+ )
575
+ return model
eva_vit_model/rope.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+
30
+ class VisionRotaryEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ pt_seq_len,
35
+ ft_seq_len=None,
36
+ custom_freqs = None,
37
+ freqs_for = 'lang',
38
+ theta = 10000,
39
+ max_freq = 10,
40
+ num_freqs = 1,
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
+
60
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
+
63
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
+
65
+ self.register_buffer("freqs_cos", freqs.cos())
66
+ self.register_buffer("freqs_sin", freqs.sin())
67
+
68
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
+
70
+ def forward(self, t, start_index = 0):
71
+ rot_dim = self.freqs_cos.shape[-1]
72
+ end_index = start_index + rot_dim
73
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
+
77
+ return torch.cat((t_left, t, t_right), dim = -1)
78
+
79
+ class VisionRotaryEmbeddingFast(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ pt_seq_len,
84
+ ft_seq_len=None,
85
+ custom_freqs = None,
86
+ freqs_for = 'lang',
87
+ theta = 10000,
88
+ max_freq = 10,
89
+ num_freqs = 1,
90
+ patch_dropout = 0.
91
+ ):
92
+ super().__init__()
93
+ if custom_freqs:
94
+ freqs = custom_freqs
95
+ elif freqs_for == 'lang':
96
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
+ elif freqs_for == 'pixel':
98
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
+ elif freqs_for == 'constant':
100
+ freqs = torch.ones(num_freqs).float()
101
+ else:
102
+ raise ValueError(f'unknown modality {freqs_for}')
103
+
104
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
105
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
+
107
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
108
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
+
111
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
+
114
+ self.patch_dropout = patch_dropout
115
+
116
+ self.register_buffer("freqs_cos", freqs_cos)
117
+ self.register_buffer("freqs_sin", freqs_sin)
118
+
119
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
+
121
+ def forward(self, t, patch_indices_keep=None):
122
+ if patch_indices_keep is not None:
123
+ batch = t.size()[0]
124
+ batch_indices = torch.arange(batch)
125
+ batch_indices = batch_indices[..., None]
126
+
127
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
+
130
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
+
135
+ return t * freqs_cos + rotate_half(t) * freqs_sin
136
+
137
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
eva_vit_model/transformer.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ from typing import Callable, Optional, Sequence
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ if os.getenv('ENV_TYPE') == 'deepspeed':
12
+ try:
13
+ import deepspeed
14
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
15
+ except:
16
+ print("Please 'pip install deepspeed'")
17
+ deepspeed = None
18
+ from torch.utils.checkpoint import checkpoint
19
+ else:
20
+ from torch.utils.checkpoint import checkpoint
21
+
22
+ try:
23
+ import xformers.ops as xops
24
+ except ImportError:
25
+ xops = None
26
+ print("Please 'pip install xformers'")
27
+
28
+ class LayerNormFp32(nn.LayerNorm):
29
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def forward(self, x: torch.Tensor):
34
+ output = F.layer_norm(
35
+ x.float(),
36
+ self.normalized_shape,
37
+ self.weight.float() if self.weight is not None else None,
38
+ self.bias.float() if self.bias is not None else None,
39
+ self.eps,
40
+ )
41
+ return output.type_as(x)
42
+
43
+
44
+ class LayerNorm(nn.LayerNorm):
45
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ orig_type = x.dtype
49
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
50
+ return x.to(orig_type)
51
+
52
+ class QuickGELU(nn.Module):
53
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
54
+ def forward(self, x: torch.Tensor):
55
+ return x * torch.sigmoid(1.702 * x)
56
+
57
+
58
+ class LayerScale(nn.Module):
59
+ def __init__(self, dim, init_values=1e-5, inplace=False):
60
+ super().__init__()
61
+ self.inplace = inplace
62
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
63
+
64
+ def forward(self, x):
65
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
66
+
67
+ class PatchDropout(nn.Module):
68
+ """
69
+ https://arxiv.org/abs/2212.00794
70
+ """
71
+
72
+ def __init__(self, prob, exclude_first_token=True):
73
+ super().__init__()
74
+ assert 0 <= prob < 1.
75
+ self.prob = prob
76
+ self.exclude_first_token = exclude_first_token # exclude CLS token
77
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
78
+
79
+ def forward(self, x):
80
+ if not self.training or self.prob == 0.:
81
+ return x
82
+
83
+ if self.exclude_first_token:
84
+ cls_tokens, x = x[:, :1], x[:, 1:]
85
+ else:
86
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
87
+
88
+ batch = x.size()[0]
89
+ num_tokens = x.size()[1]
90
+
91
+ batch_indices = torch.arange(batch)
92
+ batch_indices = batch_indices[..., None]
93
+
94
+ keep_prob = 1 - self.prob
95
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
96
+
97
+ rand = torch.randn(batch, num_tokens)
98
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
99
+
100
+ x = x[batch_indices, patch_indices_keep]
101
+
102
+ if self.exclude_first_token:
103
+ x = torch.cat((cls_tokens, x), dim=1)
104
+
105
+ if self.training and os.getenv('RoPE') == '1':
106
+ return x, patch_indices_keep
107
+
108
+ return x
109
+
110
+
111
+ def _in_projection_packed(
112
+ q: torch.Tensor,
113
+ k: torch.Tensor,
114
+ v: torch.Tensor,
115
+ w: torch.Tensor,
116
+ b: Optional[torch.Tensor] = None,
117
+ ):
118
+ """
119
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
120
+ """
121
+ E = q.size(-1)
122
+ if k is v:
123
+ if q is k:
124
+ # self-attention
125
+ return F.linear(q, w, b).chunk(3, dim=-1)
126
+ else:
127
+ # encoder-decoder attention
128
+ w_q, w_kv = w.split([E, E * 2])
129
+ if b is None:
130
+ b_q = b_kv = None
131
+ else:
132
+ b_q, b_kv = b.split([E, E * 2])
133
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
134
+ else:
135
+ w_q, w_k, w_v = w.chunk(3)
136
+ if b is None:
137
+ b_q = b_k = b_v = None
138
+ else:
139
+ b_q, b_k, b_v = b.chunk(3)
140
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
141
+
142
+ class Attention(nn.Module):
143
+ def __init__(
144
+ self,
145
+ dim,
146
+ num_heads=8,
147
+ qkv_bias=True,
148
+ scaled_cosine=False,
149
+ scale_heads=False,
150
+ logit_scale_max=math.log(1. / 0.01),
151
+ attn_drop=0.,
152
+ proj_drop=0.,
153
+ xattn=False,
154
+ rope=False
155
+ ):
156
+ super().__init__()
157
+ self.scaled_cosine = scaled_cosine
158
+ self.scale_heads = scale_heads
159
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
160
+ self.num_heads = num_heads
161
+ self.head_dim = dim // num_heads
162
+ self.scale = self.head_dim ** -0.5
163
+ self.logit_scale_max = logit_scale_max
164
+
165
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
166
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
167
+ if qkv_bias:
168
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
169
+ else:
170
+ self.in_proj_bias = None
171
+
172
+ if self.scaled_cosine:
173
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
174
+ else:
175
+ self.logit_scale = None
176
+ self.attn_drop = nn.Dropout(attn_drop)
177
+ if self.scale_heads:
178
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
179
+ else:
180
+ self.head_scale = None
181
+ self.out_proj = nn.Linear(dim, dim)
182
+ self.out_drop = nn.Dropout(proj_drop)
183
+ self.xattn = xattn
184
+ self.xattn_drop = attn_drop
185
+ self.rope = rope
186
+
187
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
188
+ L, N, C = x.shape
189
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
190
+ if self.xattn:
191
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
192
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
193
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
194
+
195
+ x = xops.memory_efficient_attention(
196
+ q, k, v,
197
+ p=self.xattn_drop,
198
+ scale=self.scale if self.logit_scale is None else None,
199
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
200
+ )
201
+ else:
202
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
203
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
204
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
205
+
206
+ if self.logit_scale is not None:
207
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
208
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
209
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
210
+ attn = attn.view(-1, L, L)
211
+ else:
212
+ q = q * self.scale
213
+ attn = torch.bmm(q, k.transpose(-1, -2))
214
+
215
+ if attn_mask is not None:
216
+ if attn_mask.dtype == torch.bool:
217
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
218
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
219
+ attn_mask = new_attn_mask
220
+ attn += attn_mask
221
+
222
+ attn = attn.softmax(dim=-1)
223
+ attn = self.attn_drop(attn)
224
+
225
+ x = torch.bmm(attn, v)
226
+
227
+ if self.head_scale is not None:
228
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
229
+ x = x.view(-1, L, C)
230
+ x = x.transpose(0, 1).reshape(L, N, C)
231
+ x = self.out_proj(x)
232
+ x = self.out_drop(x)
233
+ return x
234
+
235
+ class CustomAttention(nn.Module):
236
+ def __init__(
237
+ self,
238
+ dim,
239
+ num_heads=8,
240
+ qkv_bias=True,
241
+ scaled_cosine=True,
242
+ scale_heads=False,
243
+ logit_scale_max=math.log(1. / 0.01),
244
+ attn_drop=0.,
245
+ proj_drop=0.,
246
+ xattn=False
247
+ ):
248
+ super().__init__()
249
+ self.scaled_cosine = scaled_cosine
250
+ self.scale_heads = scale_heads
251
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
252
+ self.num_heads = num_heads
253
+ self.head_dim = dim // num_heads
254
+ self.scale = self.head_dim ** -0.5
255
+ self.logit_scale_max = logit_scale_max
256
+
257
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
258
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
259
+ if qkv_bias:
260
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
261
+ else:
262
+ self.in_proj_bias = None
263
+
264
+ if self.scaled_cosine:
265
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
266
+ else:
267
+ self.logit_scale = None
268
+ self.attn_drop = nn.Dropout(attn_drop)
269
+ if self.scale_heads:
270
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
271
+ else:
272
+ self.head_scale = None
273
+ self.out_proj = nn.Linear(dim, dim)
274
+ self.out_drop = nn.Dropout(proj_drop)
275
+ self.xattn = xattn
276
+ self.xattn_drop = attn_drop
277
+
278
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
279
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
280
+ N_q, B_q, C_q = q.shape
281
+ N_k, B_k, C_k = k.shape
282
+ N_v, B_v, C_v = v.shape
283
+ if self.xattn:
284
+ # B, N, C -> B, N, num_heads, C
285
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
286
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
287
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
288
+
289
+ x = xops.memory_efficient_attention(
290
+ q, k, v,
291
+ p=self.xattn_drop,
292
+ scale=self.scale if self.logit_scale is None else None,
293
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
294
+ )
295
+ else:
296
+ # B*H, L, C
297
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
298
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
299
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
300
+
301
+ if self.logit_scale is not None:
302
+ # B*H, N_q, N_k
303
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
304
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
305
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
306
+ attn = attn.view(-1, N_q, N_k)
307
+ else:
308
+ q = q * self.scale
309
+ attn = torch.bmm(q, k.transpose(-1, -2))
310
+
311
+ if attn_mask is not None:
312
+ if attn_mask.dtype == torch.bool:
313
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
314
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
315
+ attn_mask = new_attn_mask
316
+ attn += attn_mask
317
+
318
+ attn = attn.softmax(dim=-1)
319
+ attn = self.attn_drop(attn)
320
+
321
+ x = torch.bmm(attn, v)
322
+
323
+ if self.head_scale is not None:
324
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
325
+ x = x.view(-1, N_q, C_q)
326
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
327
+ x = self.out_proj(x)
328
+ x = self.out_drop(x)
329
+ return x
330
+
331
+ class CustomResidualAttentionBlock(nn.Module):
332
+ def __init__(
333
+ self,
334
+ d_model: int,
335
+ n_head: int,
336
+ mlp_ratio: float = 4.0,
337
+ ls_init_value: float = None,
338
+ act_layer: Callable = nn.GELU,
339
+ norm_layer: Callable = LayerNorm,
340
+ scale_cosine_attn: bool = False,
341
+ scale_heads: bool = False,
342
+ scale_attn: bool = False,
343
+ scale_fc: bool = False,
344
+ cross_attn: bool = False,
345
+ xattn: bool = False,
346
+ ):
347
+ super().__init__()
348
+
349
+ self.ln_1 = norm_layer(d_model)
350
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
351
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
352
+ self.attn = CustomAttention(
353
+ d_model, n_head,
354
+ qkv_bias=True,
355
+ attn_drop=0.,
356
+ proj_drop=0.,
357
+ scaled_cosine=scale_cosine_attn,
358
+ scale_heads=scale_heads,
359
+ xattn=xattn
360
+ )
361
+
362
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
363
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
364
+
365
+ self.ln_2 = norm_layer(d_model)
366
+ mlp_width = int(d_model * mlp_ratio)
367
+ self.mlp = nn.Sequential(OrderedDict([
368
+ ("c_fc", nn.Linear(d_model, mlp_width)),
369
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
370
+ ("gelu", act_layer()),
371
+ ("c_proj", nn.Linear(mlp_width, d_model))
372
+ ]))
373
+
374
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
375
+
376
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
377
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
378
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
379
+ return q
380
+
381
+ class CustomTransformer(nn.Module):
382
+ def __init__(
383
+ self,
384
+ width: int,
385
+ layers: int,
386
+ heads: int,
387
+ mlp_ratio: float = 4.0,
388
+ ls_init_value: float = None,
389
+ act_layer: Callable = nn.GELU,
390
+ norm_layer: Callable = LayerNorm,
391
+ scale_cosine_attn: bool = True,
392
+ scale_heads: bool = False,
393
+ scale_attn: bool = False,
394
+ scale_fc: bool = False,
395
+ cross_attn: bool = False,
396
+ xattn: bool = False,
397
+ ):
398
+ super().__init__()
399
+ self.width = width
400
+ self.layers = layers
401
+ self.grad_checkpointing = False
402
+ self.xattn = xattn
403
+
404
+ self.resblocks = nn.ModuleList([
405
+ CustomResidualAttentionBlock(
406
+ width,
407
+ heads,
408
+ mlp_ratio,
409
+ ls_init_value=ls_init_value,
410
+ act_layer=act_layer,
411
+ norm_layer=norm_layer,
412
+ scale_cosine_attn=scale_cosine_attn,
413
+ scale_heads=scale_heads,
414
+ scale_attn=scale_attn,
415
+ scale_fc=scale_fc,
416
+ cross_attn=cross_attn,
417
+ xattn=xattn)
418
+ for _ in range(layers)
419
+ ])
420
+
421
+ def get_cast_dtype(self) -> torch.dtype:
422
+ return self.resblocks[0].mlp.c_fc.weight.dtype
423
+
424
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
425
+ if k is None and v is None:
426
+ k = v = q
427
+ for r in self.resblocks:
428
+ if self.grad_checkpointing and not torch.jit.is_scripting():
429
+ q = checkpoint(r, q, k, v, attn_mask)
430
+ else:
431
+ q = r(q, k, v, attn_mask=attn_mask)
432
+ return q
433
+
434
+
435
+ class ResidualAttentionBlock(nn.Module):
436
+ def __init__(
437
+ self,
438
+ d_model: int,
439
+ n_head: int,
440
+ mlp_ratio: float = 4.0,
441
+ ls_init_value: float = None,
442
+ act_layer: Callable = nn.GELU,
443
+ norm_layer: Callable = LayerNorm,
444
+ xattn: bool = False,
445
+ ):
446
+ super().__init__()
447
+
448
+ self.ln_1 = norm_layer(d_model)
449
+ if xattn:
450
+ self.attn = Attention(d_model, n_head, xattn=True)
451
+ else:
452
+ self.attn = nn.MultiheadAttention(d_model, n_head)
453
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
454
+
455
+ self.ln_2 = norm_layer(d_model)
456
+ mlp_width = int(d_model * mlp_ratio)
457
+ self.mlp = nn.Sequential(OrderedDict([
458
+ ("c_fc", nn.Linear(d_model, mlp_width)),
459
+ ("gelu", act_layer()),
460
+ ("c_proj", nn.Linear(mlp_width, d_model))
461
+ ]))
462
+
463
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
464
+ self.xattn = xattn
465
+
466
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
467
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
468
+ if self.xattn:
469
+ return self.attn(x, attn_mask=attn_mask)
470
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
471
+
472
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
473
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
474
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
475
+ return x
476
+
477
+ class Transformer(nn.Module):
478
+ def __init__(
479
+ self,
480
+ width: int,
481
+ layers: int,
482
+ heads: int,
483
+ mlp_ratio: float = 4.0,
484
+ ls_init_value: float = None,
485
+ act_layer: Callable = nn.GELU,
486
+ norm_layer: Callable = LayerNorm,
487
+ xattn: bool = False,
488
+ ):
489
+ super().__init__()
490
+ self.width = width
491
+ self.layers = layers
492
+ self.grad_checkpointing = False
493
+
494
+ self.resblocks = nn.ModuleList([
495
+ ResidualAttentionBlock(
496
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
497
+ for _ in range(layers)
498
+ ])
499
+
500
+ def get_cast_dtype(self) -> torch.dtype:
501
+ return self.resblocks[0].mlp.c_fc.weight.dtype
502
+
503
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
504
+ for r in self.resblocks:
505
+ if self.grad_checkpointing and not torch.jit.is_scripting():
506
+ x = checkpoint(r, x, attn_mask)
507
+ else:
508
+ x = r(x, attn_mask=attn_mask)
509
+ return x
510
+
511
+
512
+ class TextTransformer(nn.Module):
513
+ def __init__(
514
+ self,
515
+ context_length: int = 77,
516
+ vocab_size: int = 49408,
517
+ width: int = 512,
518
+ heads: int = 8,
519
+ layers: int = 12,
520
+ ls_init_value: float = None,
521
+ output_dim: int = 512,
522
+ act_layer: Callable = nn.GELU,
523
+ norm_layer: Callable = LayerNorm,
524
+ xattn: bool= False,
525
+ attn_mask: bool = True
526
+ ):
527
+ super().__init__()
528
+ self.context_length = context_length
529
+ self.vocab_size = vocab_size
530
+ self.width = width
531
+ self.output_dim = output_dim
532
+
533
+ self.token_embedding = nn.Embedding(vocab_size, width)
534
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
535
+ self.transformer = Transformer(
536
+ width=width,
537
+ layers=layers,
538
+ heads=heads,
539
+ ls_init_value=ls_init_value,
540
+ act_layer=act_layer,
541
+ norm_layer=norm_layer,
542
+ xattn=xattn
543
+ )
544
+
545
+ self.xattn = xattn
546
+ self.ln_final = norm_layer(width)
547
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
548
+
549
+ if attn_mask:
550
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
551
+ else:
552
+ self.attn_mask = None
553
+
554
+ self.init_parameters()
555
+
556
+ def init_parameters(self):
557
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
558
+ nn.init.normal_(self.positional_embedding, std=0.01)
559
+
560
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
561
+ attn_std = self.transformer.width ** -0.5
562
+ fc_std = (2 * self.transformer.width) ** -0.5
563
+ for block in self.transformer.resblocks:
564
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
565
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
566
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
567
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
568
+
569
+ if self.text_projection is not None:
570
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
571
+
572
+ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
573
+ if not unlocked_layers: # full freezing
574
+ for n, p in self.named_parameters():
575
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
576
+ else:
577
+ raise ValueError("Not support partial freeze")
578
+
579
+ @torch.jit.ignore
580
+ def set_grad_checkpointing(self, enable=True):
581
+ self.transformer.grad_checkpointing = enable
582
+
583
+ @torch.jit.ignore
584
+ def no_weight_decay(self):
585
+ # return {'positional_embedding', 'token_embedding'}
586
+ return {'positional_embedding'}
587
+
588
+ def get_num_layers(self):
589
+ return self.transformer.layers
590
+
591
+ def build_attention_mask(self):
592
+ # lazily create causal attention mask, with full attention between the vision tokens
593
+ # pytorch uses additive attention mask; fill with -inf
594
+ mask = torch.empty(self.context_length, self.context_length)
595
+ mask.fill_(float("-inf"))
596
+ mask.triu_(1) # zero out the lower diagonal
597
+ return mask
598
+
599
+ def forward(self, text, return_all_features: bool=False):
600
+ cast_dtype = self.transformer.get_cast_dtype()
601
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
602
+
603
+ x = x + self.positional_embedding.to(cast_dtype)
604
+ x = x.permute(1, 0, 2) # NLD -> LND
605
+ x = self.transformer(x, attn_mask=self.attn_mask)
606
+ # x = self.transformer(x) # no attention mask is applied
607
+ x = x.permute(1, 0, 2) # LND -> NLD
608
+ x = self.ln_final(x)
609
+
610
+ if not return_all_features:
611
+ # x.shape = [batch_size, n_ctx, transformer.width]
612
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
613
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
614
+ return x
615
+
616
+
617
+ def text_transformer():
618
+ model = TextTransformer(
619
+ width=1024,
620
+ output_dim=1024,
621
+ heads=16,
622
+ layers=24,
623
+ xattn=True
624
+ )
625
+ return model
eva_vit_model/uta_clip.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from . import eva_vit
8
+ from .transformer import text_transformer
9
+
10
+ class CLIP(nn.Module):
11
+ def __init__(
12
+ self,
13
+ vision_model: str = 'eva_base_p16',
14
+ ):
15
+ super().__init__()
16
+ self.visual = eva_vit.__dict__[vision_model]()
17
+ self.text = text_transformer()
18
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
19
+
20
+ def encode_image(self, image, normalize: bool = False):
21
+ features = self.visual(image)
22
+ return F.normalize(features, dim=-1) if normalize else features
23
+
24
+ def encode_text(self, text, normalize: bool = False):
25
+ features = self.text(text)
26
+ return F.normalize(features, dim=-1) if normalize else features
27
+
28
+ def forward(self, image, text):
29
+ image_features = self.encode_image(image, normalize=True)
30
+ text_features = self.encode_text(text, normalize=True)
31
+ return image_features, text_features, self.logit_scale.exp()
imagenet_zeroshot_data.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
4
+ "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
5
+ "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
6
+ "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
7
+ "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
8
+ "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
9
+ "box turtle", "banded gecko", "green iguana", "Carolina anole",
10
+ "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
11
+ "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
12
+ "American alligator", "triceratops", "worm snake", "ring-necked snake",
13
+ "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
14
+ "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
15
+ "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
16
+ "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
17
+ "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
18
+ "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
19
+ "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
20
+ "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
21
+ "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
22
+ "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
23
+ "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
24
+ "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
25
+ "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
26
+ "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
27
+ "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
28
+ "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
29
+ "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
30
+ "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
31
+ "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
32
+ "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
33
+ "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
34
+ "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
35
+ "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
36
+ "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
37
+ "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
38
+ "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
39
+ "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
40
+ "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
41
+ "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
42
+ "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
43
+ "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
44
+ "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
45
+ "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
46
+ "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
47
+ "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
48
+ "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
49
+ "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
50
+ "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
51
+ "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
52
+ "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
53
+ "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
54
+ "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
55
+ "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
56
+ "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
57
+ "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
58
+ "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
59
+ "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
60
+ "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
61
+ "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
62
+ "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
63
+ "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
64
+ "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
65
+ "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
66
+ "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
67
+ "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
68
+ "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
69
+ "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
70
+ "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
71
+ "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
72
+ "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
73
+ "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
74
+ "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
75
+ "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
76
+ "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
77
+ "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
78
+ "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
79
+ "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
80
+ "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
81
+ "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
82
+ "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
83
+ "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
84
+ "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
85
+ "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
86
+ "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
87
+ "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
88
+ "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
89
+ "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
90
+ "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
91
+ "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
92
+ "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
93
+ "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
94
+ "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
95
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
96
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
97
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
98
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
99
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
100
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
101
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
102
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
103
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
104
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
105
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
106
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
107
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
108
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
109
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
110
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
111
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
112
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
113
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
114
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
115
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
116
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
117
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
118
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
119
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
120
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
121
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
122
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
123
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
124
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
125
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
126
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
127
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
128
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
129
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
130
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
131
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
132
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
133
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
134
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
135
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
136
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
137
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
138
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
139
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
140
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
141
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
142
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
143
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
144
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
145
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
146
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
147
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
148
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
149
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
150
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
151
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
152
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
153
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
154
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
155
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
156
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
157
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
158
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
159
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
160
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
161
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
162
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
163
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
164
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
165
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
166
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
167
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
168
+
169
+
170
+
171
+
172
+
173
+ openai_imagenet_template = [
174
+ lambda c: f'a bad photo of a {c}.',
175
+ lambda c: f'a photo of many {c}.',
176
+ lambda c: f'a sculpture of a {c}.',
177
+ lambda c: f'a photo of the hard to see {c}.',
178
+ lambda c: f'a low resolution photo of the {c}.',
179
+ lambda c: f'a rendering of a {c}.',
180
+ lambda c: f'graffiti of a {c}.',
181
+ lambda c: f'a bad photo of the {c}.',
182
+ lambda c: f'a cropped photo of the {c}.',
183
+ lambda c: f'a tattoo of a {c}.',
184
+ lambda c: f'the embroidered {c}.',
185
+ lambda c: f'a photo of a hard to see {c}.',
186
+ lambda c: f'a bright photo of a {c}.',
187
+ lambda c: f'a photo of a clean {c}.',
188
+ lambda c: f'a photo of a dirty {c}.',
189
+ lambda c: f'a dark photo of the {c}.',
190
+ lambda c: f'a drawing of a {c}.',
191
+ lambda c: f'a photo of my {c}.',
192
+ lambda c: f'the plastic {c}.',
193
+ lambda c: f'a photo of the cool {c}.',
194
+ lambda c: f'a close-up photo of a {c}.',
195
+ lambda c: f'a black and white photo of the {c}.',
196
+ lambda c: f'a painting of the {c}.',
197
+ lambda c: f'a painting of a {c}.',
198
+ lambda c: f'a pixelated photo of the {c}.',
199
+ lambda c: f'a sculpture of the {c}.',
200
+ lambda c: f'a bright photo of the {c}.',
201
+ lambda c: f'a cropped photo of a {c}.',
202
+ lambda c: f'a plastic {c}.',
203
+ lambda c: f'a photo of the dirty {c}.',
204
+ lambda c: f'a jpeg corrupted photo of a {c}.',
205
+ lambda c: f'a blurry photo of the {c}.',
206
+ lambda c: f'a photo of the {c}.',
207
+ lambda c: f'a good photo of the {c}.',
208
+ lambda c: f'a rendering of the {c}.',
209
+ lambda c: f'a {c} in a video game.',
210
+ lambda c: f'a photo of one {c}.',
211
+ lambda c: f'a doodle of a {c}.',
212
+ lambda c: f'a close-up photo of the {c}.',
213
+ lambda c: f'a photo of a {c}.',
214
+ lambda c: f'the origami {c}.',
215
+ lambda c: f'the {c} in a video game.',
216
+ lambda c: f'a sketch of a {c}.',
217
+ lambda c: f'a doodle of the {c}.',
218
+ lambda c: f'a origami {c}.',
219
+ lambda c: f'a low resolution photo of a {c}.',
220
+ lambda c: f'the toy {c}.',
221
+ lambda c: f'a rendition of the {c}.',
222
+ lambda c: f'a photo of the clean {c}.',
223
+ lambda c: f'a photo of a large {c}.',
224
+ lambda c: f'a rendition of a {c}.',
225
+ lambda c: f'a photo of a nice {c}.',
226
+ lambda c: f'a photo of a weird {c}.',
227
+ lambda c: f'a blurry photo of a {c}.',
228
+ lambda c: f'a cartoon {c}.',
229
+ lambda c: f'art of a {c}.',
230
+ lambda c: f'a sketch of the {c}.',
231
+ lambda c: f'a embroidered {c}.',
232
+ lambda c: f'a pixelated photo of a {c}.',
233
+ lambda c: f'itap of the {c}.',
234
+ lambda c: f'a jpeg corrupted photo of the {c}.',
235
+ lambda c: f'a good photo of a {c}.',
236
+ lambda c: f'a plushie {c}.',
237
+ lambda c: f'a photo of the nice {c}.',
238
+ lambda c: f'a photo of the small {c}.',
239
+ lambda c: f'a photo of the weird {c}.',
240
+ lambda c: f'the cartoon {c}.',
241
+ lambda c: f'art of the {c}.',
242
+ lambda c: f'a drawing of the {c}.',
243
+ lambda c: f'a photo of the large {c}.',
244
+ lambda c: f'a black and white photo of a {c}.',
245
+ lambda c: f'the plushie {c}.',
246
+ lambda c: f'a dark photo of a {c}.',
247
+ lambda c: f'itap of a {c}.',
248
+ lambda c: f'graffiti of the {c}.',
249
+ lambda c: f'a toy {c}.',
250
+ lambda c: f'itap of my {c}.',
251
+ lambda c: f'a photo of a cool {c}.',
252
+ lambda c: f'a photo of a small {c}.',
253
+ lambda c: f'a tattoo of the {c}.',
254
+ ]
imagenet_zeroshot_eval.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import argparse
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import torchvision.transforms as transforms
8
+ import torchvision.datasets as datasets
9
+ import torch.nn.functional as F
10
+
11
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
12
+
13
+ import eva_vit_model
14
+ from eva_vit_model import CLIP
15
+ from open_clip.tokenizer import tokenize
16
+ from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
17
+
18
+
19
+ def main(args):
20
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+ if torch.cuda.is_available():
22
+ torch.backends.cuda.matmul.allow_tf32 = True
23
+ torch.backends.cudnn.benchmark = True
24
+ torch.backends.cudnn.deterministic = False
25
+ torch.backends.cudnn.allow_tf32 = True
26
+
27
+ print(f"creating model: {args.model}")
28
+ model = CLIP(vision_model=args.model)
29
+
30
+ print(f"loading checkpoint from {args.ckpt_path}")
31
+ state_dict = torch.load(args.ckpt_path, map_location='cpu')
32
+ model.load_state_dict(state_dict, strict=True)
33
+ model.to(device)
34
+
35
+ def _convert_image_to_rgb(image):
36
+ return image.convert("RGB")
37
+
38
+ val_transform = transforms.Compose([
39
+ transforms.Resize(args.image_size, transforms.InterpolationMode.BICUBIC),
40
+ transforms.CenterCrop(args.image_size),
41
+ _convert_image_to_rgb,
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
44
+ ])
45
+
46
+ val_dataset = datasets.ImageFolder(os.path.join(args.imagenet_path, 'val'), transform=val_transform)
47
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.workers)
48
+
49
+ model.eval()
50
+ classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, device)
51
+ top1, top5 = zero_shot_eval(model, classifier, val_loader, device)
52
+ print(f'ImageNet zeroshot top1: {top1:.4f}, top5: {top5:.4f}')
53
+
54
+
55
+ def zero_shot_classifier(model, classnames, templates, device):
56
+ tokenizer = tokenize
57
+
58
+ with torch.no_grad():
59
+ zeroshot_weights = []
60
+ for classname in tqdm(classnames):
61
+ texts = [template(classname) for template in templates] # format with class
62
+ texts = tokenizer(texts).to(device=device) # tokenize
63
+ with torch.cuda.amp.autocast():
64
+ class_embeddings = model.encode_text(texts)
65
+ class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
66
+ class_embedding /= class_embedding.norm()
67
+ zeroshot_weights.append(class_embedding)
68
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
69
+ return zeroshot_weights
70
+
71
+ def accuracy(output, target, topk=(1,)):
72
+ pred = output.topk(max(topk), 1, True, True)[1].t()
73
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
74
+ return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
75
+
76
+ def zero_shot_eval(model, classifier, dataloader, device):
77
+ top1, top5, n = 0., 0., 0.
78
+ with torch.no_grad():
79
+ for images, target in tqdm(dataloader, unit_scale=args.batch_size):
80
+ images = images.to(device=device)
81
+ target = target.to(device=device)
82
+
83
+ with torch.cuda.amp.autocast():
84
+ image_features = model.encode_image(images)
85
+ image_features = F.normalize(image_features, dim=-1)
86
+ logits = 100. * image_features @ classifier
87
+
88
+ # measure accuracy
89
+ acc1, acc5 = accuracy(logits, target, topk=(1, 5))
90
+ top1 += acc1
91
+ top5 += acc5
92
+ n += images.size(0)
93
+
94
+ top1 = (top1 / n)
95
+ top5 = (top5 / n)
96
+ return top1, top5
97
+
98
+
99
+ if __name__ == '__main__':
100
+ parser = argparse.ArgumentParser(description='ImageNet zero shot evaluations', add_help=False)
101
+ parser.add_argument('--imagenet-path', default='path/to/imagenet', type=str, help='path to imagenet dataset')
102
+ parser.add_argument('--ckpt-path', default='path/to/ckpt', type=str, help='path to checkpoint')
103
+ parser.add_argument('--batch-size', default=64, type=int, help='batch size')
104
+ parser.add_argument('--model', default='eva_base_p16', type=str, help='model')
105
+ parser.add_argument('--image-size', default=224, type=int, help='image size for evaluation')
106
+ parser.add_argument('--workers', default=8, type=int)
107
+ args = parser.parse_args()
108
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tqdm
2
+ timm
3
+ torch
4
+ open_clip
5
+ torchvision
6
+ xformers