Update gigaam_transformers.py
Browse files- gigaam_transformers.py +4 -4
gigaam_transformers.py
CHANGED
@@ -276,9 +276,9 @@ class GigaAMFeatureExtractor(SequenceFeatureExtractor):
|
|
276 |
return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
|
277 |
|
278 |
|
279 |
-
class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
|
280 |
"""
|
281 |
-
Char tokenizer for GigaAM
|
282 |
"""
|
283 |
def __init__(
|
284 |
self,
|
@@ -303,7 +303,7 @@ class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
|
|
303 |
|
304 |
class GigaAMProcessor(Wav2Vec2Processor):
|
305 |
feature_extractor_class = "GigaAMFeatureExtractor"
|
306 |
-
tokenizer_class = "
|
307 |
|
308 |
def __init__(self, feature_extractor, tokenizer):
|
309 |
# super().__init__(feature_extractor, tokenizer)
|
@@ -315,7 +315,7 @@ class GigaAMProcessor(Wav2Vec2Processor):
|
|
315 |
@classmethod
|
316 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
317 |
feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
318 |
-
tokenizer =
|
319 |
|
320 |
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
321 |
|
|
|
276 |
return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
|
277 |
|
278 |
|
279 |
+
class GigaAMTokenizer(Wav2Vec2CTCTokenizer):
|
280 |
"""
|
281 |
+
Char tokenizer for GigaAM model.
|
282 |
"""
|
283 |
def __init__(
|
284 |
self,
|
|
|
303 |
|
304 |
class GigaAMProcessor(Wav2Vec2Processor):
|
305 |
feature_extractor_class = "GigaAMFeatureExtractor"
|
306 |
+
tokenizer_class = "GigaAMTokenizer"
|
307 |
|
308 |
def __init__(self, feature_extractor, tokenizer):
|
309 |
# super().__init__(feature_extractor, tokenizer)
|
|
|
315 |
@classmethod
|
316 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
317 |
feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
318 |
+
tokenizer = GigaAMTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
319 |
|
320 |
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
321 |
|