farzadab committed
Commit 4215e51 · verified · 1 parent: c84f28d

Update ultravox_processing.py

Files changed (1): ultravox_processing.py (+9 -9)
ultravox_processing.py CHANGED
@@ -113,7 +113,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             tokenizer.eos_token is not None
         ), "The tokenizer has no EOS token. Cannot recover."
         self.vocab = tokenizer.get_vocab()
-        self.audio_replacement = tokenizer.eos_token
+        self.audio_token_replacement = tokenizer.eos_token
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
@@ -188,7 +188,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             )
             is_continuation_list.append(is_continuation)
 
-        return {
+        data = {
             "audio_values": torch.stack(chunked_audio_values, dim=0),
             "audio_lens": torch.tensor(
                 chunked_audio_lens, dtype=torch.int64, device=audio_values.device
@@ -199,12 +199,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             "audio_batch_size": torch.tensor(
                 [len(chunked_audio_values)], device=audio_values.device
             ),
-            "audio_num_chunks": (
-                torch.tensor(num_chunks, dtype=torch.int64, device=audio_values.device)
-                if include_audio_num_chunks
-                else None
-            ),
         }
+        if include_audio_num_chunks:
+            data["audio_num_chunks"] = torch.tensor(
+                num_chunks, dtype=torch.int64, device=audio_values.device
+            )
+        return data
 
     def __call__(
         self,
@@ -327,7 +327,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         split_input_ids = tokenized_parts["input_ids"]
         input_ids: List[int] = []
 
-        audio_replacement_token_id = self.vocab[self.audio_replacement]
+        audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
 
         for i, token_len in enumerate(data.get("audio_token_len", [])):
             if not audio_is_continuation[i]:
@@ -341,7 +341,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
 
             audio_token_start_idx.append(len(input_ids))
 
-            input_ids.extend([audio_replacement_token_id] * token_len)
+            input_ids.extend([audio_token_replacement_token_id] * token_len)
 
         # Include any tokens after the last audio.
         placeholder_index += 1
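
The diff is easier to read as two separate changes. First, `_chunk_and_pad_audio` now builds a `data` dict and attaches the "audio_num_chunks" entry only when `include_audio_num_chunks` is set, instead of always including the key with a possible None value. Second, `audio_replacement` is renamed to `audio_token_replacement` (along with the matching local variable in `__call__`), with no behavioral change. The sketch below restates both pieces as standalone functions; it is a minimal illustration with invented names and simplified signatures, not the actual Ultravox code.

import torch

# Sketch of the new return path in _chunk_and_pad_audio. Simplified
# signature: the real method takes more arguments and keeps tensors on
# the input audio's device.
def chunk_and_pad_audio_sketch(
    chunked_audio_values,       # list of equal-length 1-D tensors (chunks)
    chunked_audio_lens,         # list of int lengths before padding
    num_chunks,                 # list of chunk counts per source audio
    include_audio_num_chunks=False,
):
    data = {
        "audio_values": torch.stack(chunked_audio_values, dim=0),
        "audio_lens": torch.tensor(chunked_audio_lens, dtype=torch.int64),
        "audio_batch_size": torch.tensor([len(chunked_audio_values)]),
    }
    # Post-commit behavior: the key is present only when requested,
    # rather than always present with a possible None value.
    if include_audio_num_chunks:
        data["audio_num_chunks"] = torch.tensor(num_chunks, dtype=torch.int64)
    return data

# Sketch of the placeholder expansion in __call__: each audio segment
# occupies token_len slots in input_ids, all filled with the replacement
# token id (the tokenizer's EOS token per the diff), and
# audio_token_start_idx records where each segment begins.
def expand_audio_placeholders_sketch(vocab, audio_token_replacement, audio_token_lens):
    audio_token_replacement_token_id = vocab[audio_token_replacement]
    input_ids, audio_token_start_idx = [], []
    for token_len in audio_token_lens:
        audio_token_start_idx.append(len(input_ids))
        input_ids.extend([audio_token_replacement_token_id] * token_len)
    return input_ids, audio_token_start_idx

Downstream collation is the likely motivation for the first change: dict entries that are sometimes None are awkward to batch, whereas an optional key can simply be absent.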