fix conflict
gcvit/models/gcvit.py  CHANGED  (+1 -24)
@@ -2,25 +2,12 @@ import numpy as np
 import tensorflow as tf
 
 from ..layers import Stem, GCViTLevel, Identity
-from ..layers import Stem, GCViTLevel, Identity
 
 
 
 BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
 TAG = 'v1.1.1'
 NAME2CONFIG = {
-    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
-                     'dim': 64,
-                     'depths': (2, 2, 6, 2),
-                     'num_heads': (2, 4, 8, 16),
-                     'mlp_ratio': 3.,
-                     'path_drop': 0.2},
-    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
-                    'dim': 64,
-                    'depths': (3, 4, 6, 5),
-                    'num_heads': (2, 4, 8, 16),
-                    'mlp_ratio': 3.,
-                    'path_drop': 0.2},
     'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
                      'dim': 64,
                      'depths': (2, 2, 6, 2),
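For orientation: the loader helpers at the bottom of this file look one of these entries up by name and forward it to the GCViT constructor. A minimal sketch of that pattern in Python, where the exact `GCViT(**config)` call is an assumption inferred from context rather than shown in this hunk:

    config = NAME2CONFIG['gcvit_xxtiny']   # e.g. {'window_size': (7, 7, 14, 7), 'dim': 64, 'depths': (2, 2, 6, 2), ...}
    model = GCViT(name='gcvit_xxtiny', **config)  # assumed constructor call; the exported GCViTXXTiny() below wraps this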
@@ -94,7 +81,6 @@ class GCViT(tf.keras.Model):
         self.levels = []
         for i in range(len(depths)):
             path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
-            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
             level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
                                downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
@@ -110,17 +96,14 @@ class GCViT(tf.keras.Model):
         else:
             raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
         self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
-        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
 
-
+
     def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
         super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel
-        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
-        super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel
 
     def forward_features(self, inputs):
         x = self.patch_embed(inputs)
@@ -137,7 +120,6 @@ class GCViT(tf.keras.Model):
             x = self.pool(x)
         if not pre_logits:
             x = self.head(x)
-            x = self.head(x)
         return x
 
     def call(self, inputs, **kwargs):
@@ -153,8 +135,6 @@ class GCViT(tf.keras.Model):
     def summary(self, input_shape=(224, 224, 3)):
         return self.build_graph(input_shape).summary()
 
-    def summary(self, input_shape=(224, 224, 3)):
-        return self.build_graph(input_shape).summary()
 
 # load standard models
 def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
@@ -179,7 +159,6 @@ def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
         model.load_weights(ckpt_path)
     return model
 
-def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
 def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_xxtiny'
     config = NAME2CONFIG[name]
@@ -215,7 +194,6 @@ def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **k
         model.load_weights(ckpt_path)
     return model
 
-def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
 def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_small'
     config = NAME2CONFIG[name]
@@ -229,7 +207,6 @@ def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
         model.load_weights(ckpt_path)
     return model
 
-def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
 def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_base'
     config = NAME2CONFIG[name]