waveletdeboshir committed on
Commit
5a4f890
·
verified ·
1 Parent(s): 9b9aeff

Update gigaam_transformers.py

Browse files
Files changed (1) hide show
  1. gigaam_transformers.py +4 -4
gigaam_transformers.py CHANGED
@@ -276,9 +276,9 @@ class GigaAMFeatureExtractor(SequenceFeatureExtractor):
276
  return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
277
 
278
 
279
- class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
280
  """
281
- Char tokenizer for GigaAM-CTC model.
282
  """
283
  def __init__(
284
  self,
@@ -303,7 +303,7 @@ class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
303
 
304
  class GigaAMProcessor(Wav2Vec2Processor):
305
  feature_extractor_class = "GigaAMFeatureExtractor"
306
- tokenizer_class = "GigaAMCTCTokenizer"
307
 
308
  def __init__(self, feature_extractor, tokenizer):
309
  # super().__init__(feature_extractor, tokenizer)
@@ -315,7 +315,7 @@ class GigaAMProcessor(Wav2Vec2Processor):
315
  @classmethod
316
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
317
  feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
318
- tokenizer = GigaAMCTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
319
 
320
  return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
321
 
 
276
  return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
277
 
278
 
279
+ class GigaAMTokenizer(Wav2Vec2CTCTokenizer):
280
  """
281
+ Char tokenizer for GigaAM model.
282
  """
283
  def __init__(
284
  self,
 
303
 
304
  class GigaAMProcessor(Wav2Vec2Processor):
305
  feature_extractor_class = "GigaAMFeatureExtractor"
306
+ tokenizer_class = "GigaAMTokenizer"
307
 
308
  def __init__(self, feature_extractor, tokenizer):
309
  # super().__init__(feature_extractor, tokenizer)
 
315
  @classmethod
316
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
317
  feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
318
+ tokenizer = GigaAMTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
319
 
320
  return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
321