wuzhiying2023 commited on
Commit
cb7fc74
1 Parent(s): 57c398d

fix NormHead eval bug

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -0
modeling_baichuan.py CHANGED
@@ -502,6 +502,7 @@ class NormHead(nn.Module):
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
 
505
  elif self.first_flag:
506
  self.first_flag = False
507
  self.weight = nn.Parameter(nn.functional.normalize(self.weight))
 
502
  def forward(self, hidden_states):
503
  if self.training:
504
  norm_weight = nn.functional.normalize(self.weight)
505
+ self.first_flag = True
506
  elif self.first_flag:
507
  self.first_flag = False
508
  self.weight = nn.Parameter(nn.functional.normalize(self.weight))