fix confidence when alg_temp > 0
#2
by
Zhihui
- opened
- generation_utils.py +2 -2
generation_utils.py
CHANGED
@@ -441,8 +441,8 @@ class DreamGenerationMixin:
|
|
441 |
if alg_temp is None or alg_temp == 0:
|
442 |
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
443 |
else:
|
444 |
-
|
445 |
-
|
446 |
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
447 |
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
448 |
x_[mask_index] = x0.clone()
|
|
|
441 |
if alg_temp is None or alg_temp == 0:
|
442 |
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
443 |
else:
|
444 |
+
full_confidence = full_confidence / alg_temp
|
445 |
+
full_confidence = F.softmax(full_confidence, dim=-1)
|
446 |
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
447 |
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
448 |
x_[mask_index] = x0.clone()
|