gpt-omni commited on
Commit
6eacc63
1 Parent(s): 008f4a1
Files changed (1) hide show
  1. utils/snac_utils.py +2 -2
utils/snac_utils.py CHANGED
@@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
21
  return input_id + shift + layer * stride
22
 
23
 
24
- def generate_audio_data(snac_tokens, snacmodel):
25
- audio = reconstruct_tensors(snac_tokens)
26
  with torch.inference_mode():
27
  audio_hat = snacmodel.decode(audio)
28
  audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
 
21
  return input_id + shift + layer * stride
22
 
23
 
24
+ def generate_audio_data(snac_tokens, snacmodel, device=None):
25
+ audio = reconstruct_tensors(snac_tokens, device)
26
  with torch.inference_mode():
27
  audio_hat = snacmodel.decode(audio)
28
  audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0