# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import jax
import jax.numpy as jnp
import numpy as np
import requests
from flax import jax_utils
from flax.core.frozen_dict import freeze
from flax.training.common_utils import shard
from jax.sharding import PartitionSpec as P
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizerFast, is_tokenizers_available
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, WhisperTokenizer
from transformers.pipelines.audio_utils import ffmpeg_read
from transformers.utils import logging

from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
from .partitioner import PjitPartitioner
from .train_state import InferenceState


logger = logging.get_logger(__name__)

# 2D parameter and activation partitioning for DP
logical_axis_rules_dp = (
    ("batch", "data"),
    ("mlp", None),
    ("heads", None),
    ("vocab", None),
    ("embed", None),
    ("embed", None),
    ("joined_kv", None),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
)


class FlaxWhisperPipline:
    def __init__(
        self,
        checkpoint="openai/whisper-large-v2",
        dtype=jnp.float32,
        batch_size=None,
        max_length=None,
    ):
| """ | |
| Args | |
| checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"): | |
| The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub | |
| with Flax weights. | |
| dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): | |
| The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and | |
| `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs. | |
| If specified all the computation will be performed with the given `dtype`. **Note that this only | |
| specifies the dtype of the computation and does not influence the dtype of model parameters.** | |
| batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`): | |
| The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing | |
| a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method. | |
| max_length (`int`, *optional*): | |
| The maximum numbers of tokens to generate. Defaults to `model.config.max_length`. | |
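
        Example (a minimal usage sketch; assumes the checkpoint's Flax weights can be downloaded from the Hugging Face
        Hub and that the class is importable, e.g. as `from whisper_jax import FlaxWhisperPipline`):

        ```python
        >>> import jax.numpy as jnp

        >>> # instantiate the pipeline in bfloat16 half-precision
        >>> pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
        ```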
| """ | |
| self.checkpoint = checkpoint | |
| self.dtype = dtype | |
| self.processor = WhisperProcessor.from_pretrained(self.checkpoint) | |
| self.feature_extractor = self.processor.feature_extractor | |
| # potentially load fast tokenizer if available | |
| tokenizer_cls = WhisperTokenizerFast if is_tokenizers_available() else WhisperTokenizer | |
| self.tokenizer = tokenizer_cls.from_pretrained(checkpoint) | |
| self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained( | |
| self.checkpoint, | |
| _do_init=False, | |
| dtype=self.dtype, | |
| ) | |
| self.max_length = max_length if max_length is not None else self.model.generation_config.max_length | |
| self.min_batch_size = jax.local_device_count() | |
| self.batch_size = ( | |
| batch_size if batch_size is not None else self.min_batch_size | |
| ) # we need a minimum of 1 batch per-device | |
| def generate(params, input_features, forced_decoder_ids, return_timestamps): | |
| output_ids = self.model.pipeline_generate( | |
| input_features, | |
| params=params, | |
| forced_decoder_ids=forced_decoder_ids, | |
| return_timestamps=return_timestamps, | |
| max_length=self.max_length, | |
| ) | |
| return output_ids | |
| # use pmap for DP by default - this is compatible on a Colab TPU v2 | |
| self.params = jax_utils.replicate(self.params) | |
| self.p_generate = jax.pmap( | |
| generate, "input_features", in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(3,) | |
| ) | |
| self.is_sharded = False | |
| def shard_params(self, num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp): | |
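        """
        Prepare the model parameters for `pjit` generation: cast them to bfloat16 via `self.model.to_bf16`, partition
        (or replicate) them across the available devices according to `logical_axis_rules`, and switch generation
        from `pmap` to `pjit`.

        Args:
            num_mp_partitions (`int`, *optional*, defaults to 1):
                Number of model-parallel partitions handed to the `PjitPartitioner`.
            logical_axis_rules (`tuple`, *optional*, defaults to `logical_axis_rules_dp`):
                Rules mapping logical parameter/activation axis names to mesh axes. The default rules map only the
                batch axis to the `"data"` mesh axis, i.e. pure data parallelism.
        """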
        def init_fn():
            input_shape = (1, self.model.config.num_mel_bins, 2 * self.model.config.max_source_positions)

            input_features = jnp.zeros(input_shape, dtype="f4")
            input_features = input_features.at[(..., -1)].set(self.model.config.eos_token_id)

            decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

            rng = jax.random.PRNGKey(0)
            init_params = self.model.module.init(
                rng,
                input_features=input_features,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                return_dict=False,
            )
            return init_params

        # Axis names metadata
        param_axes = jax.eval_shape(init_fn)["params_axes"]

        # Create InferenceState, since the partitioner expects it
        state = InferenceState(
            step=jnp.array(0),
            params=freeze(self.model.params_shape_tree),
            params_axes=freeze(param_axes),
            flax_mutables=None,
            flax_mutables_axes=param_axes,
        )

        partitioner = PjitPartitioner(num_partitions=num_mp_partitions, logical_axis_rules=logical_axis_rules)

        mesh_axes = partitioner.get_mesh_axes(state)
        params_spec = mesh_axes.params

        p_shard_params = partitioner.partition(self.model.to_bf16, (params_spec,), params_spec)

        # This will auto-magically run in mesh context
        self.params = p_shard_params(freeze(jax_utils.unreplicate(self.params)))
        self.is_sharded = True

        def generate(params, input_features, forced_decoder_ids, return_timestamps):
            output_ids = self.model.pipeline_generate(
                input_features,
                params=params,
                forced_decoder_ids=forced_decoder_ids,
                return_timestamps=return_timestamps,
                max_length=self.max_length,
            )
            return output_ids

        # Use pjit for generate only once we've sharded the params
        self.p_generate = partitioner.partition(
            generate,
            in_axis_resources=(params_spec, P("data"), None),
            out_axis_resources=P("data"),
            static_argnums=(3,),
        )

    def generate(self, input_features, language=None, task=None, return_timestamps=False):
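        """
        Generate predicted token ids for a batch of input features, forcing the decoder prompt tokens according to
        `language`, `task` and `return_timestamps`. Dispatches to the `pmap`-ed generate function by default, or to
        the `pjit`-ed one once `shard_params` has been called.
        """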
        forced_decoder_ids = self.get_forced_decoder_ids(
            language=language, task=task, return_timestamps=return_timestamps
        )
        if not self.is_sharded:
            # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
            output_ids = self.p_generate(
                freeze(self.params), shard(input_features), forced_decoder_ids, return_timestamps
            ).sequences
            output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
        else:
            # pjit handles replication / gathering for us auto-magically
            output_ids = self.p_generate(
                freeze(self.params), input_features, forced_decoder_ids, return_timestamps
            ).sequences
        return output_ids

    def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
        if generation_config is None:
            generation_config = self.model.generation_config

        if hasattr(generation_config, "is_multilingual"):
            is_multilingual = generation_config.is_multilingual
        else:
            is_multilingual = None

        forced_decoder_ids = []

        if is_multilingual:
            if language is not None:
                language = language.lower()
                if language in generation_config.lang_to_id.keys():
                    language_token = language
                elif language in TO_LANGUAGE_CODE.values():
                    language_token = f"<|{language}|>"
                elif language in TO_LANGUAGE_CODE.keys():
                    language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
                else:
                    if len(language) == 2:
                        # ISO 639-1 language code
                        acceptable_languages = list(TO_LANGUAGE_CODE.values())
                    elif "<" in language or "|" in language or ">" in language:
                        # generation config language code
                        acceptable_languages = list(generation_config.lang_to_id.keys())
                    else:
                        # language passed as a string
                        acceptable_languages = list(TO_LANGUAGE_CODE.keys())
                    raise ValueError(
                        f"Unsupported language: {language}. Language should be one of: {acceptable_languages}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))

            if task is not None:
                forced_decoder_ids.append((2, generation_config.task_to_id[task]))
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

        if not return_timestamps:
            if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

        return forced_decoder_ids

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
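        """
        Slice the raw audio `inputs` into overlapping chunks of `chunk_len` samples (one chunk every
        `chunk_len - stride_left - stride_right` samples), compute their input features, and yield them in batches of
        at most `batch_size`, together with the stride information needed to stitch the transcriptions back together.
        """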
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right
        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)

        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        for idx in batch_idx:
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = chunk_start_idx + chunk_len

            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            chunk_lens = [chunk.shape[0] for chunk in chunks]
            strides = [
                (chunk_l, _stride_l, _stride_r)
                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
            ]

            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
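        """
        Load and normalise the audio input (file path, URL, bytes, numpy array or dict), resampling it to the feature
        extractor's sampling rate if required, and yield batches of input features. If `chunk_length_s` is non-zero,
        the audio is split into overlapping chunks via `chunk_iter_with_batch`.
        """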
        if isinstance(inputs, np.ndarray):
            logger.warning(
                "Numpy array passed as input - no sampling rate checks will be performed. "
                "It is strongly recommended to pass the input as a dictionary with an 'array' key "
                "containing the numpy array representing the audio, and a 'sampling_rate' key "
                "containing the sampling rate associated with the audio array. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        if isinstance(inputs, dict):
            stride = inputs.get("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and "array" in inputs):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
                    "containing the numpy array representing the audio, and a 'sampling_rate' key "
                    "containing the sampling rate associated with the audio array."
                )

            in_sampling_rate = inputs.get("sampling_rate")
            inputs = inputs.get("array", None)

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be greater than stride length")
            for item in self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            ):
                yield item
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
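        """
        Unpack the batched model outputs into per-sample dicts, convert the strides from samples back to seconds, and
        decode the predicted token ids to text (optionally with timestamps) with the tokenizer's `_decode_asr` method.
        """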
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
        time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                chunk_len /= sampling_rate
                stride_left /= sampling_rate
                stride_right /= sampling_rate
                output["stride"] = chunk_len, stride_left, stride_right

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}

    def forward(self, model_inputs, batch_size=None, language=None, task=None, return_timestamps=False):
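        """
        Pad the final batch of input features up to `batch_size` if necessary (so that it can be split evenly across
        devices), run generation, and return the predicted tokens together with any stride information.
        """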
        # We need to keep track of some additional input arguments for post-processing,
        # so we forward them on after running generation
        input_features = model_inputs.pop("input_features")
        input_batch_size = input_features.shape[0]

        if input_batch_size != batch_size:
            padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
            input_features = np.concatenate([input_features, padding])

        pred_ids = self.generate(input_features, language=language, task=task, return_timestamps=return_timestamps)[
            :input_batch_size
        ]

        # tokenizer's decode method expects an extra dim - we insert it here for convenience
        out = {"tokens": pred_ids[:, None, :]}

        stride = model_inputs.pop("stride", None)
        if stride is not None:
            out["stride"] = stride

        return out

    def __call__(
        self,
        inputs,
        chunk_length_s=30.0,
        stride_length_s=None,
        batch_size=None,
        language=None,
        task=None,
        return_timestamps=None,
        generate_kwargs=None,
    ):
| """ | |
| Transcribe an audio input sequence to a text transcription, optionally with timestamps. | |
| Args: | |
| inputs (`np.ndarray` or `bytes` or `str` or `dict`): | |
| The inputs is either: | |
| - `str` that is the filename of the audio file, the file will be read at the correct sampling rate | |
| to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system. | |
| - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the | |
| same way. | |
| - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) | |
| Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling | |
| rate check will be done. | |
| - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this | |
| pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array": | |
| np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to | |
| ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in | |
| decoding (but used at inference to provide more context to the model). In general, this additional | |
| stride argument is not required. | |
| chunk_length_s (`float`, *optional*, defaults to 30.0): | |
| The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk | |
| length is set 30.0s, equal to Whisper's context window. | |
| stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`): | |
| The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables | |
| the model to *see* more context and infer letters better than without this context but the pipeline | |
| discards the stride bits at the end to make the final reconstitution as perfect as possible. | |
| <Tip> | |
| For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking | |
| blog post](https://huggingface.co/blog/asr-chunking). | |
| </Tip> | |
| batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`): | |
| The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing | |
| a batch size in the `__call__` method will supersede any batch size passed to the `__init__`. | |
| task (`str`, *optional*): | |
| Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`. | |
| language (`str`, *optional*): | |
| Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`. | |
| Defaults to `None`, meaning the language is automatically inferred from the audio input. | |
| return_timestamps (*optional*, `bool`): | |
| Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline | |
| will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"` | |
| containing the transcription segments chunked by their utterance-level timestamps. | |
| Return: | |
| `Dict`: A dictionary with the following keys: | |
| - **text** (`str` ) -- The recognised text. | |
| - **chunks** (*optional(, `List[Dict]`) | |
| When using `return_timestamps`, the `chunks` will become a list containing all the various text | |
| chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text": | |
| "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing | |
| `"".join(chunk["text"] for chunk in output["chunks"])`. | |
| """ | |
| batch_size = batch_size if batch_size is not None else self.batch_size | |
| if batch_size % self.min_batch_size != 0: | |
| raise ValueError( | |
| f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}." | |
| ) | |
| dataloader = self.preprocess_batch( | |
| inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size | |
| ) | |
| model_outputs = [] | |
| # iterate over our chunked audio samples | |
| for batch in dataloader: | |
| model_outputs.append( | |
| self.forward( | |
| batch, batch_size=batch_size, language=language, task=task, return_timestamps=return_timestamps | |
| ) | |
| ) | |
| post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps) | |
| return post_processed | |