katuni4ka commited on
Commit
6a4422e
·
verified ·
1 Parent(s): a5601a0

Update resampler.py

Browse files
Files changed (1) hide show
  1. resampler.py +9 -0
resampler.py CHANGED
@@ -117,6 +117,15 @@ class Resampler(nn.Module):
117
  self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
118
  self._set_2d_pos_cache(self.max_size, device)
119
 
 
 
 
 
 
 
 
 
 
120
  def _init_weights(self, m):
121
  if isinstance(m, nn.Linear):
122
  trunc_normal_(m.weight, std=0.02)
 
117
  self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
118
  self._set_2d_pos_cache(self.max_size, device)
119
 
120
+ def _initialize_weights(self, module):
121
+ """
122
+ Initialize the weights if they are not already initialized.
123
+ """
124
+ if getattr(module, "_is_hf_initialized", False):
125
+ return
126
+ self._init_weights(module)
127
+ module._is_hf_initialized = True
128
+
129
  def _init_weights(self, m):
130
  if isinstance(m, nn.Linear):
131
  trunc_normal_(m.weight, std=0.02)