jiacheng-ye committed on
Commit
baec25c
·
verified ·
1 Parent(s): 0bacc53

Upload model

Browse files
Files changed (1) hide show
  1. generation_utils.py +21 -6
generation_utils.py CHANGED
@@ -302,8 +302,9 @@ class DreamGenerationMixin:
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
305
- tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
306
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
 
 
307
 
308
  # 2. Define model inputs
309
  assert inputs is not None
@@ -355,6 +356,8 @@ class DreamGenerationMixin:
355
  input_ids,
356
  attention_mask=attention_mask,
357
  generation_config=generation_config,
 
 
358
  )
359
  return result
360
 
@@ -363,6 +366,8 @@ class DreamGenerationMixin:
363
  input_ids: torch.LongTensor,
364
  attention_mask: Optional[torch.LongTensor],
365
  generation_config: DreamGenerationConfig,
 
 
366
  ) -> Union[DreamModelOutput, torch.LongTensor]:
367
  # init values
368
  output_history = generation_config.output_history
@@ -398,11 +403,18 @@ class DreamGenerationMixin:
398
  attention_mask = "full"
399
 
400
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
 
 
 
401
  for i in range(steps):
402
  mask_index = (x == mask_token_id)
403
  logits = self(x, attention_mask, tok_idx).logits
404
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
405
- logits = logits[mask_index]
 
 
 
 
406
  t = timesteps[i]
407
  s = timesteps[i + 1]
408
 
@@ -410,15 +422,15 @@ class DreamGenerationMixin:
410
  p_transfer = 1 - s / t if i < steps - 1 else 1
411
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
412
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
413
- _, x0[transfer_index_t_s]= sample_tokens(logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
414
  x[mask_index] = x0.clone()
415
  else:
416
  if alg == 'maskgit_plus':
417
- confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
418
  elif alg == 'topk_margin':
419
- confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
420
  elif alg == 'entropy':
421
- confidence, x0 = sample_tokens(logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
422
  else:
423
  raise RuntimeError(f"Unknown alg: {alg}")
424
  num_mask_token = mask_index.sum()
@@ -434,6 +446,9 @@ class DreamGenerationMixin:
434
  x0_[transfer_index] = x0[transfer_index].clone()
435
  x[mask_index] = x0_
436
 
 
 
 
437
  if histories is not None:
438
  histories.append(x.clone())
439
 
 
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
 
305
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
306
+ generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
307
+ generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
308
 
309
  # 2. Define model inputs
310
  assert inputs is not None
 
356
  input_ids,
357
  attention_mask=attention_mask,
358
  generation_config=generation_config,
359
+ generation_tokens_hook_func=generation_tokens_hook_func,
360
+ generation_logits_hook_func=generation_logits_hook_func
361
  )
362
  return result
363
 
 
366
  input_ids: torch.LongTensor,
367
  attention_mask: Optional[torch.LongTensor],
368
  generation_config: DreamGenerationConfig,
369
+ generation_tokens_hook_func,
370
+ generation_logits_hook_func
371
  ) -> Union[DreamModelOutput, torch.LongTensor]:
372
  # init values
373
  output_history = generation_config.output_history
 
403
  attention_mask = "full"
404
 
405
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
406
+
407
+ # this allows user-defined token control of the intermediate steps
408
+ x = generation_tokens_hook_func(None, x, None)
409
  for i in range(steps):
410
  mask_index = (x == mask_token_id)
411
  logits = self(x, attention_mask, tok_idx).logits
412
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
413
+
414
+ # this allows user-defined logits control of the intermediate steps
415
+ logits = generation_logits_hook_func(i, x, logits)
416
+
417
+ mask_logits = logits[mask_index]
418
  t = timesteps[i]
419
  s = timesteps[i + 1]
420
 
 
422
  p_transfer = 1 - s / t if i < steps - 1 else 1
423
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
424
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
425
+ _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
426
  x[mask_index] = x0.clone()
427
  else:
428
  if alg == 'maskgit_plus':
429
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
430
  elif alg == 'topk_margin':
431
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
432
  elif alg == 'entropy':
433
+ confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
434
  else:
435
  raise RuntimeError(f"Unknown alg: {alg}")
436
  num_mask_token = mask_index.sum()
 
446
  x0_[transfer_index] = x0[transfer_index].clone()
447
  x[mask_index] = x0_
448
 
449
+ # this allows user-defined token control of the intermediate steps
450
+ x = generation_tokens_hook_func(i, x, logits)
451
+
452
  if histories is not None:
453
  histories.append(x.clone())
454