Update ultravox_processing.py
Browse files- ultravox_processing.py +29 -1
ultravox_processing.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
from typing import Optional, Union
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import transformers
|
6 |
|
|
|
|
|
7 |
|
8 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
9 |
"""
|
@@ -56,6 +58,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
56 |
), "The tokenizer has no EOS token. Cannot recover."
|
57 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def __call__(
|
60 |
self,
|
61 |
text: Optional[str] = None,
|
@@ -175,3 +200,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
175 |
tokenizer_input_names = self.tokenizer.model_input_names
|
176 |
audio_processor_input_names = self.audio_processor.model_input_names
|
177 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, Dict, Any
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import transformers
|
6 |
|
7 |
+
from .ultravox_config import UltravoxConfig
|
8 |
+
|
9 |
|
10 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
11 |
"""
|
|
|
58 |
), "The tokenizer has no EOS token. Cannot recover."
|
59 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
60 |
|
61 |
+
@classmethod
|
62 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
63 |
+
config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
|
64 |
+
pretrained_model_name_or_path, **kwargs
|
65 |
+
)
|
66 |
+
audio_processor = transformers.AutoProcessor.from_pretrained(
|
67 |
+
config.audio_model_id
|
68 |
+
or config.audio_config._name_or_path
|
69 |
+
or "facebook/wav2vec2-base-960h"
|
70 |
+
)
|
71 |
+
|
72 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
73 |
+
pretrained_model_name_or_path, **kwargs
|
74 |
+
)
|
75 |
+
tokenizer.padding_side = "left"
|
76 |
+
tokenizer.pad_token = tokenizer.eos_token
|
77 |
+
|
78 |
+
return cls(
|
79 |
+
audio_processor=audio_processor,
|
80 |
+
tokenizer=tokenizer,
|
81 |
+
stack_factor=config.stack_factor,
|
82 |
+
)
|
83 |
+
|
84 |
def __call__(
|
85 |
self,
|
86 |
text: Optional[str] = None,
|
|
|
200 |
tokenizer_input_names = self.tokenizer.model_input_names
|
201 |
audio_processor_input_names = self.audio_processor.model_input_names
|
202 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
203 |
+
|
204 |
+
|
205 |
+
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
|