Update resampler.py
Browse files- resampler.py +9 -0
resampler.py
CHANGED
@@ -117,6 +117,15 @@ class Resampler(nn.Module):
|
|
117 |
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
|
118 |
self._set_2d_pos_cache(self.max_size, device)
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
def _init_weights(self, m):
|
121 |
if isinstance(m, nn.Linear):
|
122 |
trunc_normal_(m.weight, std=0.02)
|
|
|
117 |
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
|
118 |
self._set_2d_pos_cache(self.max_size, device)
|
119 |
|
120 |
+
def _initialize_weights(self, module):
|
121 |
+
"""
|
122 |
+
Initialize the weights if they are not already initialized.
|
123 |
+
"""
|
124 |
+
if getattr(module, "_is_hf_initialized", False):
|
125 |
+
return
|
126 |
+
self._init_weights(module)
|
127 |
+
module._is_hf_initialized = True
|
128 |
+
|
129 |
def _init_weights(self, m):
|
130 |
if isinstance(m, nn.Linear):
|
131 |
trunc_normal_(m.weight, std=0.02)
|