Upload model
- config.json +1 -1
- generation_utils.py +10 -7
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "Dream-org/Dream-v0-
+  "_name_or_path": "Dream-org/Dream-7B-instruct-v0-preview",
   "architectures": [
     "DreamModel"
   ],
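Note: "DreamModel" is a custom architecture whose code (e.g. generation_utils.py) ships with the repository, so loading it goes through trust_remote_code. The snippet below is a hedged usage sketch only; it assumes the repository id matches the updated "_name_or_path" value, which this diff does not confirm.

import torch
from transformers import AutoModel, AutoTokenizer

# Assumption: the repo id equals the new "_name_or_path"; substitute the real id if it differs.
repo_id = "Dream-org/Dream-7B-instruct-v0-preview"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
# trust_remote_code=True lets transformers import the custom model code
# (including generation_utils.py) bundled with the checkpoint.
model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)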
generation_utils.py
CHANGED
@@ -433,18 +433,21 @@ class DreamGenerationMixin:
                     confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
                 else:
                     raise RuntimeError(f"Unknown alg: {alg}")
-                num_mask_token = mask_index.sum()
-                number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
+                num_mask_token = mask_index.sum() / mask_index.shape[0]
+                number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
+                full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
+                full_confidence[mask_index] = confidence
                 if number_transfer_tokens > 0:
                     if alg_temp is None or alg_temp == 0:
-                        _, transfer_index = torch.topk(
+                        _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
                     else:
                         confidence = confidence / alg_temp
                         confidence = F.softmax(confidence, dim=-1)
-                        transfer_index = torch.multinomial(
-
-
-                x
+                        transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
+                    x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
+                    x_[mask_index] = x0.clone()
+                    row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
+                    x[row_indices,transfer_index] = x_[row_indices,transfer_index]
 
                 # this allows user-defined token control of the intermediate steps
                 x = generation_tokens_hook_func(i, x, logits)
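The generation_utils.py change batches the per-step token commit: confidences of the masked positions are scattered into a full (batch, seq_len) grid padded with -inf, the per-row transfer budget is derived from the average mask count, and the selected positions are written back with row-wise indexing. Below is a minimal, self-contained sketch of that pattern with toy shapes and dummy values; the tensor names mirror the diff, but everything else is illustrative rather than the repository's actual code.

import torch

# Toy sizes; in the real loop these come from the model and tokenizer.
batch_size, seq_len, mask_token_id = 2, 8, 151666
x = torch.randint(0, 100, (batch_size, seq_len))           # current sequences
mask_index = torch.zeros(batch_size, seq_len, dtype=torch.bool)
mask_index[:, 3:] = True                                    # last 5 positions of each row still masked
x[mask_index] = mask_token_id

# Stand-ins for the outputs of sample_tokens(): one proposed token and one
# confidence per masked position, flattened across the batch.
x0 = torch.randint(0, 100, (int(mask_index.sum()),))
confidence = torch.rand(int(mask_index.sum()))

# Per-row budget, assuming every row has the same number of masked tokens.
num_mask_token = mask_index.sum() / mask_index.shape[0]
number_transfer_tokens = int(num_mask_token * 0.5)          # e.g. commit half of the remaining masks

# Scatter confidences onto a (batch, seq_len) grid; unmasked slots stay -inf
# so top-k can never pick an already-committed position.
full_confidence = torch.full(x.shape, float("-inf"))
full_confidence[mask_index] = confidence

# Greedy branch of the diff: take the most confident positions per row.
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)

# Lay the proposals out on a full-size grid, then copy only the selected
# positions into x with row-wise advanced indexing.
x_ = torch.full_like(x, mask_token_id)
x_[mask_index] = x0
row_indices = torch.arange(x.size(0)).unsqueeze(1).expand_as(transfer_index)
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
print(x)

The -inf padding keeps torch.topk away from positions that are already decoded, and the row_indices/transfer_index pair applies the update independently to every sequence in the batch rather than to a single flattened sequence.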