Update ultravox_processing.py

ultravox_processing.py (+9 -9)
@@ -113,7 +113,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             tokenizer.eos_token is not None
         ), "The tokenizer has no EOS token. Cannot recover."
         self.vocab = tokenizer.get_vocab()
-        self.
+        self.audio_token_replacement = tokenizer.eos_token
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id

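The processor now records the tokenizer's EOS token as the stand-in token written wherever audio will later be spliced in, which is why the preceding assert insists on an EOS token. A minimal sketch of the lookup this enables, assuming an off-the-shelf Hugging Face tokenizer (gpt2 is used here purely for illustration, not part of the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    assert tokenizer.eos_token is not None, "The tokenizer has no EOS token. Cannot recover."

    vocab = tokenizer.get_vocab()
    audio_token_replacement = tokenizer.eos_token
    # The id that later fills audio positions in input_ids:
    audio_token_replacement_token_id = vocab[audio_token_replacement]
    print(audio_token_replacement, audio_token_replacement_token_id)  # <|endoftext|> 50256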
@@ -188,7 +188,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             )
             is_continuation_list.append(is_continuation)

-        return {
+        data = {
             "audio_values": torch.stack(chunked_audio_values, dim=0),
             "audio_lens": torch.tensor(
                 chunked_audio_lens, dtype=torch.int64, device=audio_values.device
@@ -199,12 +199,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             "audio_batch_size": torch.tensor(
                 [len(chunked_audio_values)], device=audio_values.device
             ),
-            "audio_num_chunks": (
-                torch.tensor(num_chunks, dtype=torch.int64, device=audio_values.device)
-                if include_audio_num_chunks
-                else None
-            ),
         }
+        if include_audio_num_chunks:
+            data["audio_num_chunks"] = torch.tensor(
+                num_chunks, dtype=torch.int64, device=audio_values.device
+            )
+        return data

     def __call__(
         self,
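Together, these two hunks change how the optional "audio_num_chunks" entry is returned: the mapping is built up in a local data dict, and the key is added only when include_audio_num_chunks is set, rather than always being present with a possible None value. A likely motivation (my reading, not stated in the commit) is that downstream tensor conversion and collation code generally cannot handle a None entry, while an absent key is simply skipped. A toy sketch of the pattern with a hypothetical builder, not the Ultravox code itself:

    import torch

    def build_features(num_chunks, include_audio_num_chunks):
        # Required entries first...
        data = {
            "audio_batch_size": torch.tensor([len(num_chunks)]),
        }
        # ...optional entries only when present, so no key ever maps to None.
        if include_audio_num_chunks:
            data["audio_num_chunks"] = torch.tensor(num_chunks, dtype=torch.int64)
        return data

    print(build_features([2, 3], True))   # both keys present
    print(build_features([2, 3], False))  # "audio_num_chunks" absent, not None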
@@ -327,7 +327,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         split_input_ids = tokenized_parts["input_ids"]
         input_ids: List[int] = []

-
+        audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]

         for i, token_len in enumerate(data.get("audio_token_len", [])):
             if not audio_is_continuation[i]:
@@ -341,7 +341,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):

             audio_token_start_idx.append(len(input_ids))

-            input_ids.extend([
+            input_ids.extend([audio_token_replacement_token_id] * token_len)

             # Include any tokens after the last audio.
             placeholder_index += 1
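These last two hunks make the placeholder expansion explicit: the replacement id is resolved once from the vocab before the loop, and each audio clip then contributes token_len copies of that id to input_ids, with the span's start index recorded in audio_token_start_idx. A self-contained toy version of the pattern, ignoring the interleaved text tokens and continuation chunks the real method handles:

    from typing import List, Tuple

    def expand_audio_placeholders(
        audio_token_lens: List[int], replacement_token_id: int
    ) -> Tuple[List[int], List[int]]:
        # Each clip occupies token_len positions filled with the same
        # replacement token; span starts are recorded so audio embeddings
        # can be spliced in at those offsets later.
        input_ids: List[int] = []
        audio_token_start_idx: List[int] = []
        for token_len in audio_token_lens:
            audio_token_start_idx.append(len(input_ids))
            input_ids.extend([replacement_token_id] * token_len)
        return input_ids, audio_token_start_idx

    ids, starts = expand_audio_placeholders([3, 2], replacement_token_id=50256)
    print(ids)     # [50256, 50256, 50256, 50256, 50256]
    print(starts)  # [0, 3]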
|