Upload model
generation_utils.py  (+21 −6)  CHANGED
@@ -302,8 +302,9 @@ class DreamGenerationMixin:
         **kwargs,
     ) -> Union[DreamModelOutput, torch.LongTensor]:
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
         generation_config = self._prepare_generation_config(generation_config, **kwargs)
+        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
+        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
 
         # 2. Define model inputs
         assert inputs is not None
@@ -355,6 +356,8 @@
             input_ids,
             attention_mask=attention_mask,
             generation_config=generation_config,
+            generation_tokens_hook_func=generation_tokens_hook_func,
+            generation_logits_hook_func=generation_logits_hook_func
         )
         return result
 
@@ -363,6 +366,8 @@
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor],
         generation_config: DreamGenerationConfig,
+        generation_tokens_hook_func,
+        generation_logits_hook_func
     ) -> Union[DreamModelOutput, torch.LongTensor]:
         # init values
         output_history = generation_config.output_history
@@ -398,11 +403,18 @@
             attention_mask = "full"
 
         timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
+
+        # this allows user-defined token control of the intermediate steps
+        x = generation_tokens_hook_func(None, x, None)
         for i in range(steps):
             mask_index = (x == mask_token_id)
             logits = self(x, attention_mask, tok_idx).logits
             logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
-            mask_logits = logits[mask_index]
+
+            # this allows user-defined logits control of the intermediate steps
+            logits = generation_logits_hook_func(i, x, logits)
+
+            mask_logits = logits[mask_index]
             t = timesteps[i]
             s = timesteps[i + 1]
 
@@ -410,15 +422,15 @@
                 p_transfer = 1 - s / t if i < steps - 1 else 1
                 x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
                 transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
-                _, x0[transfer_index_t_s]= sample_tokens(…)
+                _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
                 x[mask_index] = x0.clone()
             else:
                 if alg == 'maskgit_plus':
-                    confidence, x0 = sample_tokens(…)
+                    confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
                 elif alg == 'topk_margin':
-                    confidence, x0 = sample_tokens(…)
+                    confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
                 elif alg == 'entropy':
-                    confidence, x0 = sample_tokens(…)
+                    confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
                 else:
                     raise RuntimeError(f"Unknown alg: {alg}")
                 num_mask_token = mask_index.sum()
@@ -434,6 +446,9 @@
                 x0_[transfer_index] = x0[transfer_index].clone()
                 x[mask_index] = x0_
 
+            # this allows user-defined token control of the intermediate steps
+            x = generation_tokens_hook_func(i, x, logits)
+
             if histories is not None:
                 histories.append(x.clone())
 
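The two new kwargs give callers step-level control over decoding without subclassing the mixin: the tokens hook fires once before the loop (with step=None, logits=None) and again after every denoising step, while the logits hook fires on each step's shifted logits before sampling. Below is a minimal usage sketch, not part of this commit. It assumes the usual Dream loading recipe (AutoModel with trust_remote_code=True), the mixin's public entry point diffusion_generate forwarding extra kwargs into _sample as shown above, and return_dict_in_generate=True producing an output with .sequences; the model id, prompt, hook bodies, and banned token id are illustrative only.

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "Dream-org/Dream-v0-Instruct-7B"  # assumed id for this repo

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16)

BANNED_TOKEN_ID = 11  # placeholder id, for illustration only

def tokens_hook(step, x, logits):
    # Called once before the loop (step=None, logits=None) and after every
    # step; must return the (possibly edited) token tensor. Returning x
    # unchanged reproduces the default lambda from the diff above.
    return x

def logits_hook(step, x, logits):
    # Called on each step's logits before sampling; zeroing out an id's
    # probability here keeps it out of every intermediate unmasking decision.
    logits[..., BANNED_TOKEN_ID] = float("-inf")
    return logits

inputs = tokenizer("The capital of France is", return_tensors="pt")
out = model.diffusion_generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=32,
    steps=32,
    return_dict_in_generate=True,
    generation_tokens_hook_func=tokens_hook,
    generation_logits_hook_func=logits_hook,
)
print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))

Because the tokens hook also runs before the first step, it can pre-fill positions in x (for example, pinning a template token); since mask_index is recomputed from x on every iteration, the loop then treats those positions as already unmasked and never overwrites them.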