# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| Script to perform buffered inference using RNNT models. | |
| Buffered inference is the primary form of audio transcription when the audio segment is longer than 20-30 seconds. | |
| This is especially useful for models such as Conformers, which have quadratic time and memory scaling with | |
| audio duration. | |
| The difference between streaming and buffered inference is the chunk size (or the latency of inference). | |
| Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context. | |
| Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context. | |
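
Two merge algorithms are supported for stitching the per-buffer outputs together. In brief: the middle
token algorithm keeps only the tokens emitted for the central chunk of each buffer, while the LCS
algorithm aligns the overlapping token sequences of adjacent buffers with a longest common subsequence
match before merging.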

# Middle Token merge algorithm

python speech_to_text_buffered_infer_rnnt.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="<remove or path to folder of audio files>" \
    dataset_manifest="<remove or path to manifest>" \
    output_filename="<remove or specify output filename>" \
    total_buffer_in_secs=4.0 \
    chunk_len_in_secs=1.6 \
    model_stride=4 \
    batch_size=32

# Longest Common Subsequence (LCS) Merge algorithm

python speech_to_text_buffered_infer_rnnt.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="<remove or path to folder of audio files>" \
    dataset_manifest="<remove or path to manifest>" \
    output_filename="<remove or specify output filename>" \
    total_buffer_in_secs=4.0 \
    chunk_len_in_secs=1.6 \
    model_stride=4 \
    batch_size=32 \
    merge_algo="lcs" \
    lcs_alignment_dir=<OPTIONAL: Some path to store the LCS alignments>

# NOTE:
    You can use `DEBUG=1 python speech_to_text_buffered_infer_rnnt.py ...` to print out the
    predictions of the model, and the ground-truth text if present in the manifest.
"""

import copy
import glob
import math
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.parts.utils.streaming_utils import (
    BatchedFrameASRRNNT,
    LongestCommonSubsequenceBatchedFrameASRRNNT,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    get_buffered_pred_feat_rnnt,
    setup_model,
    write_transcription,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging

can_gpu = torch.cuda.is_available()


@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # Sets mode of work; if True, it will add a new field with the transcriptions.
    pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than the standard one.

    # Chunked configs
    chunk_len_in_secs: float = 1.6  # Chunk length in seconds
    total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
    model_stride: int = 8  # Model downsampling factor: 8 for Citrinet models and 4 for Conformer models

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    audio_type: str = "wav"

    # Recompute model transcription, even if the output folder exists with scores.
    overwrite_transcripts: bool = True

    # Decoding configs
    max_steps_per_timestep: int = 5  # Maximum number of tokens decoded per acoustic timestep
    stateful_decoding: bool = False  # Whether to perform stateful decoding

    # Merge algorithm for transducers
    merge_algo: Optional[str] = 'middle'  # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
    lcs_alignment_dir: Optional[str] = None  # Path to a directory to store LCS algo alignments


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None  # ignore dataset_manifest if both audio_dir and dataset_manifest are present

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
    logging.info(f"Inference will be done on device : {device}")

    asr_model, model_name = setup_model(cfg, map_location)

    model_cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(model_cfg.preprocessor, False)
    # some changes for streaming scenario
    model_cfg.preprocessor.dither = 0.0
    model_cfg.preprocessor.pad_to = 0
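    # Note: `dither` adds random noise to the input signal and `pad_to` pads the feature sequence to a
    # fixed multiple; both are disabled here so features stay consistent across overlapping buffers.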

    if model_cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")

    # Disable config overwriting
    OmegaConf.set_struct(model_cfg.preprocessor, True)

    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten, and already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    asr_model.freeze()
    asr_model = asr_model.to(asr_model.device)

    # Change Decoding Config
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        if cfg.stateful_decoding:
            decoding_cfg.strategy = "greedy"
        else:
            decoding_cfg.strategy = "greedy_batch"

        decoding_cfg.preserve_alignments = True  # required to compute the middle token for transducers.
        decoding_cfg.fused_batch_size = -1  # temporarily stop fused batch during inference.

    asr_model.change_decoding_strategy(decoding_cfg)

    feature_stride = model_cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * cfg.model_stride
    total_buffer = cfg.total_buffer_in_secs
    chunk_len = float(cfg.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

    if cfg.merge_algo == 'middle':
        frame_asr = BatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=cfg.total_buffer_in_secs,
            batch_size=cfg.batch_size,
            max_steps_per_timestep=cfg.max_steps_per_timestep,
            stateful_decoding=cfg.stateful_decoding,
        )

    elif cfg.merge_algo == 'lcs':
        frame_asr = LongestCommonSubsequenceBatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=cfg.total_buffer_in_secs,
            batch_size=cfg.batch_size,
            max_steps_per_timestep=cfg.max_steps_per_timestep,
            stateful_decoding=cfg.stateful_decoding,
            alignment_basepath=cfg.lcs_alignment_dir,
        )
        # Set the LCS algorithm delay.
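        # (total_buffer - chunk_len) is the overlap, in seconds, between consecutive buffers; expressed in
        # model output timesteps, it gives the region over which the LCS merge aligns adjacent hypotheses.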
        frame_asr.lcs_delay = math.floor((total_buffer - chunk_len) / model_stride_in_secs)

    else:
        raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")

    hyps = get_buffered_pred_feat_rnnt(
        asr=frame_asr,
        tokens_per_chunk=tokens_per_chunk,
        delay=mid_delay,
        model_stride_in_secs=model_stride_in_secs,
        batch_size=cfg.batch_size,
        manifest=manifest,
        filepaths=filepaths,
    )

    output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
    logging.info(f"Finished writing predictions to {output_filename}!")

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
