fix confidence when alg_temp > 0

#2
by Zhihui - opened
Files changed (1) hide show
  1. generation_utils.py +2 -2
generation_utils.py CHANGED
@@ -441,8 +441,8 @@ class DreamGenerationMixin:
441
  if alg_temp is None or alg_temp == 0:
442
  _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
443
  else:
444
- confidence = confidence / alg_temp
445
- confidence = F.softmax(confidence, dim=-1)
446
  transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
447
  x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
448
  x_[mask_index] = x0.clone()
 
441
  if alg_temp is None or alg_temp == 0:
442
  _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
443
  else:
444
+ full_confidence = full_confidence / alg_temp
445
+ full_confidence = F.softmax(full_confidence, dim=-1)
446
  transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
447
  x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
448
  x_[mask_index] = x0.clone()