GalaktischeGurke
commited on
Saving train state of step 5000
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +1 -0
- Makefile +9 -0
- README.md +638 -0
- added_tokens.json +1611 -0
- checkpoint-5000-epoch-0/model.safetensors +3 -0
- checkpoint-5000-epoch-0/model_1.safetensors +3 -0
- checkpoint-5000-epoch-0/optimizer.bin +3 -0
- checkpoint-5000-epoch-0/random_states_0.pkl +3 -0
- checkpoint-5000-epoch-0/random_states_1.pkl +3 -0
- checkpoint-5000-epoch-0/scheduler.bin +3 -0
- config.json +50 -0
- core +3 -0
- create_student_model.py +231 -0
- distil-large-v3-init/added_tokens.json +1611 -0
- distil-large-v3-init/config.json +50 -0
- distil-large-v3-init/generation_config.json +255 -0
- distil-large-v3-init/merges.txt +0 -0
- distil-large-v3-init/model.safetensors +3 -0
- distil-large-v3-init/normalizer.json +1742 -0
- distil-large-v3-init/preprocessor_config.json +14 -0
- distil-large-v3-init/special_tokens_map.json +139 -0
- distil-large-v3-init/tokenizer_config.json +0 -0
- distil-large-v3-init/vocab.json +0 -0
- distil_whisper.egg-info/PKG-INFO +655 -0
- distil_whisper.egg-info/SOURCES.txt +8 -0
- distil_whisper.egg-info/dependency_links.txt +1 -0
- distil_whisper.egg-info/requires.txt +12 -0
- distil_whisper.egg-info/top_level.txt +1 -0
- flax/LICENSE +201 -0
- flax/Makefile +9 -0
- flax/README.md +293 -0
- flax/conversion_scripts/run_convert_distilled_train_state_to_hf.sh +8 -0
- flax/convert_train_state_to_hf.py +327 -0
- flax/create_student_model.py +226 -0
- flax/distil_whisper/__init__.py +21 -0
- flax/distil_whisper/layers.py +1338 -0
- flax/distil_whisper/modeling_flax_whisper.py +2135 -0
- flax/distil_whisper/partitioner.py +965 -0
- flax/distil_whisper/pipeline.py +527 -0
- flax/distil_whisper/train_state.py +118 -0
- flax/distillation_scripts/run_32_2_pt.sh +38 -0
- flax/distillation_scripts/run_bs_sweep.yaml +67 -0
- flax/distillation_scripts/run_dataset_sweep.yaml +77 -0
- flax/distillation_scripts/run_decoder_sweep.yaml +72 -0
- flax/distillation_scripts/run_distillation_12_2_timestamped.sh +42 -0
- flax/distillation_scripts/run_distillation_15s_context.sh +43 -0
- flax/distillation_scripts/run_distillation_16_2.sh +41 -0
- flax/distillation_scripts/run_distillation_24_2.sh +42 -0
- flax/distillation_scripts/run_distillation_24_2_timestamped.sh +42 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
core filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
wandb
|
Makefile
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
check_dirs := .
|
2 |
+
|
3 |
+
quality:
|
4 |
+
black --check $(check_dirs)
|
5 |
+
ruff $(check_dirs)
|
6 |
+
|
7 |
+
style:
|
8 |
+
black $(check_dirs)
|
9 |
+
ruff $(check_dirs) --fix
|
README.md
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Training Distil-Whisper
|
2 |
+
|
3 |
+
This sub-folder contains all the scripts required to train a Distil-Whisper model in your choice of language. They are
|
4 |
+
slightly modified from the original scripts used to distill Whisper for English ASR (as-per the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
|
5 |
+
The main difference is that these scripts are written in [PyTorch](https://pytorch.org), whereas the original scripts
|
6 |
+
are in [JAX](https://jax.readthedocs.io/en/latest/#)/[Flax](https://flax.readthedocs.io/en/latest/). These scripts are
|
7 |
+
also made to be easier to run end-to-end, whereas the original scripts require more steps and are somewhat hard-coded
|
8 |
+
for English ASR. Both sets of scripts achieve equivalent downstream results when the hyper-parameters are set equal.
|
9 |
+
|
10 |
+
If you are interested in reproducing the original Distil-Whisper checkpoints, we refer you to the sub-folder [Flax Training](./flax/README.md).
|
11 |
+
Otherwise, if you wish to distill Whisper on your own language/dataset, we recommend you use these scripts for ease of use
|
12 |
+
and the configurability they provide.
|
13 |
+
|
14 |
+
Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
|
15 |
+
|
16 |
+
1. [Pseudo-labelling](#1-pseudo-labelling)
|
17 |
+
2. [Initialisation](#2-initialisation)
|
18 |
+
3. [Training](#3-training)
|
19 |
+
4. [Evaluation](#4-evaluation)
|
20 |
+
|
21 |
+
This README is partitioned according to the four stages. Each section provides a minimal example for running the
|
22 |
+
scripts used in the project. We will use a running example of distilling the Whisper model for Hindi speech recognition
|
23 |
+
on the Common Voice dataset. Note that this dataset only contains ~20 hours of audio data. Thus, it can be run extremely
|
24 |
+
quickly, but does not provide sufficient data to achieve optimal performance. We recommend training on upwards of 1000
|
25 |
+
hours of data should you want to match the performance of Whisper on high-resource languages.
|
26 |
+
|
27 |
+
## Requirements
|
28 |
+
|
29 |
+
The Distil-Whisper training code is written in [PyTorch](https://pytorch.org) and [Accelerate](https://huggingface.co/docs/accelerate/index).
|
30 |
+
It heavily leverages the Whisper implementation in [🤗 Transformers](https://github.com/huggingface/transformers) for both
|
31 |
+
training and inference.
|
32 |
+
|
33 |
+
The instructions for installing the package are as follows:
|
34 |
+
1. Install PyTorch from the [official instructions](https://pytorch.org/get-started/locally/), ensuring you install the correct version for your hardware and CUDA version.
|
35 |
+
2. Fork the `distil-whisper` repository by clicking on the [fork](https://github.com/huggingface/distil-whisper/fork) button on the reopsitory's page
|
36 |
+
3. Clone the `distil-whisper` repository and add the base repository as a remote. This will allow you to "pull" any upstream changes that are made to the base repository:
|
37 |
+
|
38 |
+
```bash
|
39 |
+
git clone https://github.com/<your GitHub handle>/distil-whisper.git
|
40 |
+
cd distil-whisper
|
41 |
+
git remote add upstream https://github.com/huggingface/distil-whisper.git
|
42 |
+
```
|
43 |
+
4. pip install the required packages from the [setup.py](./setup.py) file:
|
44 |
+
```bash
|
45 |
+
cd training
|
46 |
+
pip install -e .
|
47 |
+
cd ../..
|
48 |
+
```
|
49 |
+
|
50 |
+
5. Configure Accelerate by running the following command. Note that you should set the number of GPUs you wish to use for distillation, and also the data type (dtype) to your preferred dtype for training/inference (e.g. `bfloat16` on A100 GPUs, `float16` on V100 GPUs, etc.):
|
51 |
+
|
52 |
+
```bash
|
53 |
+
accelerate config
|
54 |
+
```
|
55 |
+
|
56 |
+
6. The last thing we need to do is link our Hugging Face account so that we can pull/push model repositories on the Hub. This will allow us to save our final distilled weights on the Hub so that we can share them with the community. Run the command:
|
57 |
+
|
58 |
+
```bash
|
59 |
+
git config --global credential.helper store
|
60 |
+
huggingface-cli login
|
61 |
+
```
|
62 |
+
And then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.
|
63 |
+
|
64 |
+
To confirm that you have a working environment, first accept the terms of use of the Common Voice 16.1 dataset on the Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
|
65 |
+
|
66 |
+
You can run the following code cell to stream one sample of data from the Common Voice dataset, and check that you can
|
67 |
+
perform inference using the "tiny" Whisper model:
|
68 |
+
|
69 |
+
```python
|
70 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
71 |
+
from datasets import load_dataset, Audio
|
72 |
+
|
73 |
+
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", low_cpu_mem_usage=True)
|
74 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
75 |
+
|
76 |
+
model.to("cuda")
|
77 |
+
|
78 |
+
common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
|
79 |
+
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
|
80 |
+
|
81 |
+
inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
|
82 |
+
input_features = inputs.input_features
|
83 |
+
|
84 |
+
generated_ids = model.generate(input_features.to("cuda"), max_new_tokens=128)
|
85 |
+
pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)
|
86 |
+
|
87 |
+
print("Pred text:", pred_text)
|
88 |
+
print("Environment set up successful?", generated_ids.shape[-1] == 20)
|
89 |
+
```
|
90 |
+
|
91 |
+
## 1. Pseudo-Labelling
|
92 |
+
|
93 |
+
The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
|
94 |
+
to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
|
95 |
+
with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
|
96 |
+
datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
|
97 |
+
blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
|
98 |
+
|
99 |
+
> As of the latest Distil-Whisper release, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3), this
|
100 |
+
pseudo-labelling script also performs the added operation of concatenating (or packing) the audio inputs to 30-seconds.
|
101 |
+
Not only does this lead to a WER improvement when using sequential long-form decoding algorithm, but concatenating audios
|
102 |
+
to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised.
|
103 |
+
|
104 |
+
The following script demonstrates how to pseudo-label the Hindi split of the Common Voice 16.1 dataset with greedy sampling:
|
105 |
+
|
106 |
+
```bash
|
107 |
+
#!/usr/bin/env bash
|
108 |
+
|
109 |
+
accelerate launch run_pseudo_labelling.py \
|
110 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
111 |
+
--dataset_name "mozilla-foundation/common_voice_16_1" \
|
112 |
+
--dataset_config_name "hi" \
|
113 |
+
--dataset_split_name "train+validation+test" \
|
114 |
+
--text_column_name "sentence" \
|
115 |
+
--id_column_name "path" \
|
116 |
+
--output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
117 |
+
--wandb_project "distil-whisper-labelling" \
|
118 |
+
--per_device_eval_batch_size 64 \
|
119 |
+
--dtype "bfloat16" \
|
120 |
+
--attn_implementation "sdpa" \
|
121 |
+
--logging_steps 500 \
|
122 |
+
--max_label_length 256 \
|
123 |
+
--concatenate_audio \
|
124 |
+
--preprocessing_batch_size 500 \
|
125 |
+
--preprocessing_num_workers 8 \
|
126 |
+
--dataloader_num_workers 8 \
|
127 |
+
--report_to "wandb" \
|
128 |
+
--language "hi" \
|
129 |
+
--task "transcribe" \
|
130 |
+
--return_timestamps \
|
131 |
+
--streaming False \
|
132 |
+
--generation_num_beams 1 \
|
133 |
+
--push_to_hub
|
134 |
+
```
|
135 |
+
|
136 |
+
On an 80 GB A100 GPU, the following script takes approximately 5 minutes to concatenate and pre-process the 20 hours of
|
137 |
+
audio data, and a further 10 minutes to transcribe the pseudo-labels. The pseudo-labelled dataset corresponding to this
|
138 |
+
script is available on the Hugging Face Hub under [sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled](https://huggingface.co/datasets/sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled).
|
139 |
+
The WER of the pre-trained Whisper large-v3 model is 17.2% on the test split. We will compare the performance of our distilled model against this number.
|
140 |
+
|
141 |
+
There are two noteworthy arguments that configure the dataset concatenation (or packing) process:
|
142 |
+
1. `concatenate_audio`: whether or not to concatenate (or pack) the audios to 30-second chunks. The latest Distil-Whisper model, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3#differences-with-distil-large-v2), highlights the WER improvements obtained using the sequential long-form decoding algorithm when concatenated audios are used. Concatenating audios to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised. Hence, it is highly recommended to set `--concatenate_audio=True`.
|
143 |
+
2. `preprocessing_batch_size`: the batch size to use when concatenating (or packing) the audios. Using a larger batch size results in a greater portion of audio samples being packed to 30-seconds, at the expense of higher memory consumption. If you exceed your system's RAM when performing the concatenation operation, reduce the `preprocessing_batch_size` by a factor of 2 to 250 or even 125.
|
144 |
+
3. `preprocessing_num_workers`: the number of multiprocessing workers to use when concatenating the audios. Using more workers will result in faster pre-processing, at the expense of higher memory consumption. Ensure you do not exceed the maximum number of CPUs on your device.
|
145 |
+
|
146 |
+
In addition, the following arguments configure the inference of the Whisper model:
|
147 |
+
1. `language`: explicitly setting the language token during inference substantially improves the generation performance of the Whisper model, since the model is forced always to predict in the given language. We recommend you set the language to the language you wish to distil the Whisper model on. The only exception is when distilling an English-only model (i.e. where the model id is appended with an `.en`, e.g. `small.en`), the language argument should be set to None, since there is no language token used during training/inference.
|
148 |
+
2. `return_timestamps`: whether or not to predict timestamps in the pseudo-labels. Timestamp prediction is required should you want your distilled model to be able to predict timestamps at inference time (e.g. for the original OpenAI long-form transcription algorithm). However, the pseudo-labels are marginally less accurate than not using timestamps. We recommend pseudo-labelling **with** timestamps to ensure the distilled model is as general as possible.
|
149 |
+
3. `attn_implementation`: which attention implementation to use for inference. Set to `sdpa` for [PyTorch SDPA](https://huggingface.co/docs/transformers/v4.35.2/en/perf_infer_gpu_one#bettertransformer), or `flash_attn_2` if your hardware supports Flash Attention 2 and you have the [package installed](https://github.com/Dao-AILab/flash-attention).
|
150 |
+
4. `streaming`: whether or not to use Datasets' streaming mode. If enabled, the audio data will be streamed from the Hugging Face Hub with no disk space requirements. However, the user is then responsible for adding the pseudo-labels to the dataset script in a follow-up step (see [Using Streaming Mode](#TODO)). If set to `False`, the audio data will be downloaded and pre-processed offline. At the end of pseudo-labelling, the pseudo-labels will be automatically appended to the original dataset, meaning the dataset is ready to be used for the subsequent training step without any additional steps.
|
151 |
+
5. `generation_num_beams`: how many beams to use while decoding. In practice, we found the distilled model to perform comparably when the data was pseudo-labelled with `generation_num_beams=1` (greedy) or `generation_num_beams>1` (beam). This is likely because the WER filter compensates for the lower quality pseudo-labels obtained using greedy search. However, using `generation_num_beams=1` gives substantially faster inference time for the pseudo-labelling step, and so we recommend this configuration.
|
152 |
+
|
153 |
+
Should you have your own audio dataset, you can first [convert it](https://huggingface.co/docs/datasets/audio_dataset) to
|
154 |
+
Hugging Face Datasets format and push it to the Hugging Face Hub. You can then pseudo-label it using the script above,
|
155 |
+
replacing the `--dataset_name` with the name of your dataset on the Hub.
|
156 |
+
|
157 |
+
Otherwise, you may wish to use an open-source dataset already available on the Hugging Face Hub. We provide a summary of
|
158 |
+
the three most popular multilingual datasets in the table below. For more details, refer to the blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#multilingual-speech-recognition).
|
159 |
+
|
160 |
+
| Dataset | Languages | Domain | Speaking Style | License | Text Column | ID Column |
|
161 |
+
|-----------------------------------------------------------------------------------------------|-----------|---------------------------------------|----------------|-----------|---------------------|--------------|
|
162 |
+
| [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) | 6 | Audiobooks | Narrated | CC-BY-4.0 | `"text"` | `"id"` |
|
163 |
+
| [Common Voice 16](https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1) | 120 | Wikipedia text & crowd-sourced speech | Narrated | CC0-1.0 | `"sentence"` | `"path"` |
|
164 |
+
| [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) | 15 | European Parliament recordings | Spontaneous | CC0 | `"normalized_text"` | `"audio_id"` |
|
165 |
+
|
166 |
+
To achieve *robustness* to different distributions of audio data, it is recommended to train on multiple datasets where possible.
|
167 |
+
For example, the above three datasets all have splits for the German language. Thus, if distilling a Whisper model for German,
|
168 |
+
it would be wise to use a combination of the three datasets during training, in order to cover at least three distinct domains
|
169 |
+
(audiobooks, crowd-sourced speech, parliament recordings). You may wish to use a combination of open-source datasets, or
|
170 |
+
a combination of open-source and individually owned datasets to cover multiple distributions and domains. Moreover, if you were to train on low-resource datasets (<500 hours), you could experiment with [language mixing](#3-language-mixing) to improve robustness.
|
171 |
+
|
172 |
+
## 2. Initialisation
|
173 |
+
|
174 |
+
The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
|
175 |
+
from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
|
176 |
+
initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
|
177 |
+
recommendations.
|
178 |
+
|
179 |
+
First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the required files
|
180 |
+
to reproduce the training run, alongside model weights, training logs and a README.md card. You can either create a model
|
181 |
+
repository directly on the Hugging Face Hub using the link: https://huggingface.co/new. Or, via the CLI, as we'll show here.
|
182 |
+
|
183 |
+
Let's pick a name for our distilled model: `distil-whisper-large-v3-hi`. We can run the following command to create a repository under this name:
|
184 |
+
|
185 |
+
```bash
|
186 |
+
huggingface-cli repo create distil-whisper-large-v3-hi
|
187 |
+
```
|
188 |
+
|
189 |
+
We can now see the model on the Hub, e.g. under https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
|
190 |
+
|
191 |
+
Let's clone the repository so that we can place our training script and model weights inside:
|
192 |
+
|
193 |
+
```bash
|
194 |
+
git lfs install
|
195 |
+
git clone https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
|
196 |
+
```
|
197 |
+
|
198 |
+
Be sure to change the repo address to `https://huggingface.co/<your-user-name>/<your-repo-name>`
|
199 |
+
|
200 |
+
We can now copy the relevant training scrips to the repository:
|
201 |
+
```bash
|
202 |
+
cd distil-whisper-large-v3-hi
|
203 |
+
|
204 |
+
cp ../distil-whisper/training/create_student_model.py .
|
205 |
+
cp ../distil-whisper/training/run_distillation.py .
|
206 |
+
```
|
207 |
+
|
208 |
+
The following command demonstrates how to initialise a student model from the Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3)
|
209 |
+
checkpoint, with all 32 encoder layer and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
|
210 |
+
1 and 32 respectively, as the maximally spaced layers:
|
211 |
+
|
212 |
+
```bash
|
213 |
+
#!/usr/bin/env bash
|
214 |
+
|
215 |
+
python create_student_model.py \
|
216 |
+
--teacher_checkpoint "openai/whisper-large-v3" \
|
217 |
+
--encoder_layers 32 \
|
218 |
+
--decoder_layers 2 \
|
219 |
+
--save_dir "./distil-large-v3-init"
|
220 |
+
```
|
221 |
+
|
222 |
+
The initialised model will be saved to the sub-directory `distil-large-v3-init` in our model repository.
|
223 |
+
|
224 |
+
|
225 |
+
**Note:** You can leverage language transfer by setting `--teacher_checkpoint` to "distil-whisper/distil-large-v3", see [language transfer](#22-language-transfer) for more details.
|
226 |
+
|
227 |
+
## 3. Training
|
228 |
+
|
229 |
+
The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
|
230 |
+
datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
|
231 |
+
from the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), which is a weighted sum of the cross-entropy and
|
232 |
+
KL-divergence loss terms.
|
233 |
+
|
234 |
+
The following command takes the Common Voice dataset that was pseudo-labelled in the first stage and trains the
|
235 |
+
2-layer decoder model intialised in the previous step. We pass the local path to the pseudo-labelled Common Voice dataset
|
236 |
+
(`../common_voice_16_1_hi_pseudo_labelled`), which you can change to the path where your local pseudo-labelled dataset is
|
237 |
+
saved.
|
238 |
+
|
239 |
+
In this example, we will combine the train and validation splits to give our training set, and evaluate on the test split
|
240 |
+
only. This is purely to demonstrate how to combine multiple pseudo-labelled datasets for training, rather than recommended
|
241 |
+
advice for defining train/validation splits. We advise that you train on the train splits of your dataset, evaluate and
|
242 |
+
tune hyper-parameters on the validation split, and only test the final checkpoint on the test split. Note how multiple
|
243 |
+
training datasets and splits can be loaded by separating the dataset arguments by `+` symbols. Thus, the script generalises
|
244 |
+
to any number of training datasets.
|
245 |
+
|
246 |
+
```bash
|
247 |
+
#!/usr/bin/env bash
|
248 |
+
|
249 |
+
accelerate launch run_distillation.py \
|
250 |
+
--model_name_or_path "./distil-large-v3-init" \
|
251 |
+
--teacher_model_name_or_path "openai/whisper-large-v3" \
|
252 |
+
--train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
|
253 |
+
--train_split_name "train+validation" \
|
254 |
+
--text_column_name "sentence+sentence" \
|
255 |
+
--train_dataset_samples "7+4" \
|
256 |
+
--eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
|
257 |
+
--eval_split_name "test" \
|
258 |
+
--eval_text_column_name "sentence" \
|
259 |
+
--eval_steps 1000 \
|
260 |
+
--save_steps 1000 \
|
261 |
+
--warmup_steps 50 \
|
262 |
+
--learning_rate 0.0001 \
|
263 |
+
--lr_scheduler_type "constant_with_warmup" \
|
264 |
+
--timestamp_probability 0.2 \
|
265 |
+
--condition_on_prev_probability 0.2 \
|
266 |
+
--language "hi" \
|
267 |
+
--task "transcribe" \
|
268 |
+
--logging_steps 25 \
|
269 |
+
--save_total_limit 1 \
|
270 |
+
--max_steps 5000 \
|
271 |
+
--wer_threshold 20 \
|
272 |
+
--per_device_train_batch_size 32 \
|
273 |
+
--per_device_eval_batch_size 32 \
|
274 |
+
--dataloader_num_workers 8 \
|
275 |
+
--preprocessing_num_workers 8 \
|
276 |
+
--ddp_timeout 7200 \
|
277 |
+
--dtype "bfloat16" \
|
278 |
+
--attn_implementation "sdpa" \
|
279 |
+
--output_dir "./" \
|
280 |
+
--do_train \
|
281 |
+
--do_eval \
|
282 |
+
--gradient_checkpointing \
|
283 |
+
--overwrite_output_dir \
|
284 |
+
--predict_with_generate \
|
285 |
+
--freeze_encoder \
|
286 |
+
--freeze_embed_positions \
|
287 |
+
--streaming False \
|
288 |
+
--push_to_hub
|
289 |
+
|
290 |
+
```
|
291 |
+
|
292 |
+
The above training script will take approximately 3 hours to complete on an 80 GB A100 GPU and yield a final WER of 76%.
|
293 |
+
While the generations are starting to take form, there is still a 59% WER gap to the teacher model. This is hardly
|
294 |
+
surprising give we only have 15 hours of un-filtered data, and closer to just 1.5 hours with data filtering.
|
295 |
+
As mentioned above, using upwards of 1000 hours of data and training for 10k steps will likely yield
|
296 |
+
more competitive performance. For the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), we trained on 21k hours
|
297 |
+
of audio data for 80k steps. We found that upwards of 13k hours of audio data was required to reach convergence on English
|
298 |
+
ASR (see Section 9.2 of the [paper](https://arxiv.org/abs/2311.00430)), so the more data you have, the better!
|
299 |
+
|
300 |
+
Scaling to multiple GPUs using [distributed data parallelism (DDP)](https://pytorch.org/tutorials/beginner/ddp_series_theory.html)
|
301 |
+
is trivial: simply run `accelerate config` and select the multi-GPU option, specifying the IDs of the GPUs you wish to use. The
|
302 |
+
above script can then be run using DDP with no code changes.
|
303 |
+
|
304 |
+
Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
|
305 |
+
saved checkpoint pushed to the Hugging Face Hub can be found here: [sanchit-gandhi/distil-whisper-large-v3-hi](https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi).
|
306 |
+
|
307 |
+
There are a few noteworthy data arguments:
|
308 |
+
1. `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
|
309 |
+
2. `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. In our English distillation experiments, we found a WER threshold of 10% provides the optimal trade-off between ensuring high-quality transcriptions, and not filtering unnecessary amounts of training data. For multilingual distillation, the threshold should be set in accordance with the WER achieved by the pre-trained model on the test set.
|
310 |
+
3. `streaming`: whether or not to use Datasets' streaming mode. Recommended for large datasets, where the audio data can be streamed from the Hugging Face Hub with no disk space requirements.
|
311 |
+
4. `timestamp_probability`: the per-sample probability for retaining timestamp tokens in the labels (should they contain them). Retaining some portion of timestamp tokens in the training data is required to ensure the distilled model can predict timestamps at inference time. In our experiments, we found that training on timestamps with high-probability hurts the distilled model's transcription performance. Thus, we recommend setting this to a value below 0.5. Typically, a value of 0.2 works well, giving good transcription and timestamp performance.
|
312 |
+
5. `condition_on_prev_probability`: the per-sample probability for conditioning on previous labels. Conditioning on previous tokens is required to ensure the distilled model can be used with the "sequential" long-form transcription algorithm at inference time. We did not experiment with this parameter, but found values around 0.2 to provide adequate performance. OpenAI pre-trained Whisper on with a 50% probability for conditioning on previous tokens. Thus, you might wish to try higher values.
|
313 |
+
|
314 |
+
As well as a few noteworthy model arguments that can be configured to give optimal training performance:
|
315 |
+
1. `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
|
316 |
+
2. `freeze_embed_positions`: whether to freeze the student model's decoder positional embeddings. Using the same embed positions as the teacher model, which is designed to handle context lengths up to 448 tokens, helps the student model retain its input id representation up to the full max input length.
|
317 |
+
3. `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
|
318 |
+
4. `freeze_decoder`: whether to freeze the student model's decoder. Note that the input tokens embeddings and language modelling head will remain trainable.
|
319 |
+
|
320 |
+
And finally, a few noteworthy training arguments:
|
321 |
+
1. `max_steps`: defines the total number of optimisation steps (forward + backward pass) during training. To reach convergence, you should use a dataset of at least 1k hours and train for a minimum of 50k steps.
|
322 |
+
2. `lr_scheduler_stype`: defines the learning rate schedule, one of `constant_with_warmup` or `linear`. When experimenting with a training set-up or training for very few steps (< 5k), using `constant_with_warmup` is typically beneficial, since the learning rate remains high over the short training run. When performing long training runs (> 5k), using a `linear` schedule generally results in superior downstream performance of the distilled model.
|
323 |
+
|
324 |
+
TODO:
|
325 |
+
- [ ] Template for model cards
|
326 |
+
|
327 |
+
## 4. Evaluation
|
328 |
+
|
329 |
+
There are four types of evaluation performed in Distil-Whisper:
|
330 |
+
1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
|
331 |
+
2. Sequential long form: evaluation on audio samples longer than 30s in duration using the original "sequential" long-form algorithm. Examples include entire TED talks or earnings calls.
|
332 |
+
3. Chunked long form: evaluation on audio samples longer than 30s in duration using the Transformers "chunked" long-form algorithm.
|
333 |
+
4. Speculative decoding: evaluation on audio samples less than 30s in duration, where a faster, distilled model is used as the assistant to a slower, teacher model.
|
334 |
+
|
335 |
+
All four forms of evaluation are performed using the script [`run_eval.py`](run_eval.py). Unlike the pseudo-labelling
|
336 |
+
and training scripts, the evaluation script assumes that only one GPU accelerator is used. We can copy the corresponding
|
337 |
+
evaluation script to the model repository using the following command:
|
338 |
+
|
339 |
+
```bash
|
340 |
+
cp ../distil-whisper/training/run_eval.py .
|
341 |
+
```
|
342 |
+
|
343 |
+
Models are assessed jointly using:
|
344 |
+
1. The *word-error rate (WER)* metric: measures the number of substitution, deletion and insertion errors relative to the total number of words. A lower WER indicates a more accurate model.
|
345 |
+
2. The *inverse real-time factor (RTFx)* metric: measures the ratio of `audio input time : model compute time`. A higher RTFx indicates a faster model. Note that this metric is WER-dependent, meaning that it makes sense to compare two models' *RTFx* only at fixed *WER* performances. Indeed, deletions could lead to early stopping of token generation, resulting in higher *WER* and lower *RTFx*.
|
346 |
+
3. Token generation speed: This refers to the number of tokens generated per second. As with *RTFx*, this metric is dependent on the *WER* since token generation time is not linear. By default, this metric is calculated by averaging the total number of `generated tokens : generation time` (full forward pass of the model) when evaluating on the given test set. However, using the `--precise_tok_generation` flag will compute this metric separately for a fixed number of tokens.
|
347 |
+
|
348 |
+
In all cases, it is particularly important to evaluate the final model on data that is *out-of-distribution (OOD)* with
|
349 |
+
the training data. Evaluating on OOD data provides insight as to how well the distilled model is likely to generalise to
|
350 |
+
different audio distributions at inference time. In our example, the Common Voice test set is *in-distribution (ID)*
|
351 |
+
with our training data, since it is taken from the same distribution as the Common Voice training set. Whereas the FLEURS
|
352 |
+
test set is OOD, since it is not used as part of the training set. See [Datasets](#1-datasets) section for recommendations.
|
353 |
+
|
354 |
+
### Short Form
|
355 |
+
|
356 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple short-form
|
357 |
+
validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on
|
358 |
+
the Common Voice `test` set (ID) and also the FLEURS `test` set (OOD). Again, it leverages streaming mode to bypass
|
359 |
+
the need to download the data offline:
|
360 |
+
|
361 |
+
```bash
|
362 |
+
#!/usr/bin/env bash
|
363 |
+
|
364 |
+
python run_eval.py \
|
365 |
+
--model_name_or_path "./" \
|
366 |
+
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
|
367 |
+
--dataset_config_name "default+hi_in" \
|
368 |
+
--dataset_split_name "test+test" \
|
369 |
+
--text_column_name "sentence+transcription" \
|
370 |
+
--batch_size 16 \
|
371 |
+
--dtype "bfloat16" \
|
372 |
+
--generation_max_length 256 \
|
373 |
+
--language "hi" \
|
374 |
+
--attn_implementation "sdpa" \
|
375 |
+
--streaming
|
376 |
+
|
377 |
+
```
|
378 |
+
|
379 |
+
The student model achieves an average WER of TODO% with an RTFx of TODO for a batch size of 16. We can easily adapt the above
|
380 |
+
script to evaluate the teacher model, simply by switching the `model_name_or_path` to `openai/whisper-large-v3`, which
|
381 |
+
achieves an average WER of TODO% with an RTFx of TODO. Therefore, for a batch size of 16, the student model is a factor of TODO
|
382 |
+
times faster than the teacher. The WER gap can be closed by training on more data (at least 1k hours) for more training
|
383 |
+
steps (at least 50k).
|
384 |
+
|
385 |
+
### Sequential Long Form
|
386 |
+
|
387 |
+
The original Whisper paper presents a long-form transcription algorithm that sequentially transcribes 30-second segments
|
388 |
+
of audio and shifts the sliding window according to the timestamps predicted by the model. This style of sequential
|
389 |
+
inference is performed directly using the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
|
390 |
+
method in Transformers.
|
391 |
+
|
392 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
|
393 |
+
long-form evaluation sets using the sequential algorithm. Since we don't have a long-form validation set for Hindi to hand,
|
394 |
+
in this example we'll evaluate the official Distil-Whisper model [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3)
|
395 |
+
on the TED-LIUM validation set:
|
396 |
+
|
397 |
+
```bash
|
398 |
+
#!/usr/bin/env bash
|
399 |
+
|
400 |
+
accelerate launch run_eval.py \
|
401 |
+
--model_name_or_path "distil-whisper/distil-large-v3" \
|
402 |
+
--dataset_name "distil-whisper/tedlium-long-form" \
|
403 |
+
--dataset_config_name "default" \
|
404 |
+
--dataset_split_name "validation" \
|
405 |
+
--text_column_name "text" \
|
406 |
+
--batch_size 16 \
|
407 |
+
--dtype "bfloat16" \
|
408 |
+
--generation_max_length 256 \
|
409 |
+
--language "en" \
|
410 |
+
--attn_implementation "sdpa" \
|
411 |
+
--streaming
|
412 |
+
|
413 |
+
```
|
414 |
+
|
415 |
+
### Chunked Long Form
|
416 |
+
|
417 |
+
Chunked long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
|
418 |
+
inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
|
419 |
+
A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
|
420 |
+
|
421 |
+
This style of chunked inference is performed using the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines)
|
422 |
+
class, which provides a wrapper around the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
|
423 |
+
function for long-form inference.
|
424 |
+
|
425 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
|
426 |
+
long-form evaluation sets using the pipeline class. Again, in this example we'll evaluate distil-large-v3 on the
|
427 |
+
TED-LIUM validation set:
|
428 |
+
|
429 |
+
```bash
|
430 |
+
#!/usr/bin/env bash
|
431 |
+
|
432 |
+
python run_eval.py \
|
433 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
434 |
+
--dataset_name "distil-whisper/tedlium-long-form" \
|
435 |
+
--dataset_config_name "default" \
|
436 |
+
--dataset_split_name "validation" \
|
437 |
+
--text_column_name "text" \
|
438 |
+
--use_pipeline \
|
439 |
+
--chunk_length_s 25.0 \
|
440 |
+
--language "en" \
|
441 |
+
--return_timestamps \
|
442 |
+
--dtype "bfloat16" \
|
443 |
+
--streaming
|
444 |
+
|
445 |
+
```
|
446 |
+
|
447 |
+
The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
|
448 |
+
length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
|
449 |
+
it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
|
450 |
+
can be found under [`run_chunk_length_s_sweep.yaml`](flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
|
451 |
+
|
452 |
+
### Speculative Decoding
|
453 |
+
|
454 |
+
Speculative decoding, or assisted generation, relies on the premise that a faster, assistant model can be used to speed-up
|
455 |
+
the generation of a slower, assistant model. Speculative decoding mathematically ensures that exactly the same outputs as
|
456 |
+
Whisper are obtained, while being ~2 times faster. This makes it the perfect drop-in replacement for existing Whisper
|
457 |
+
pipelines, since exactly the same outputs are guaranteed.
|
458 |
+
|
459 |
+
Distil-Whisper checkpoints can be designed to be efficient assistant models to Whisper for speculative decoding. More precisely,
|
460 |
+
by freezing the encoder during training, the distilled model can share the same encoder weights as Whisper during inference, since
|
461 |
+
the encoder weights are un-changed. In doing so, only the distilled 2-layer decoder has to be loaded in addition to the
|
462 |
+
original Whisper model, which is approximately an 8% increase to the total parameter count, with up to 2x faster inference
|
463 |
+
for low batch sizes. For more details on speculative decoding, the reader is advised to refer to the following blog post:
|
464 |
+
[Speculative Decoding for 2x Faster Whisper Inference](https://huggingface.co/blog/whisper-speculative-decoding).
|
465 |
+
|
466 |
+
In the example below, we use our distilled model as an assistant to the large-v3 teacher model during inference:
|
467 |
+
|
468 |
+
```bash
|
469 |
+
#!/usr/bin/env bash
|
470 |
+
|
471 |
+
python run_eval.py \
|
472 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
473 |
+
--assistant_model_name_or_path "./" \
|
474 |
+
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
|
475 |
+
--dataset_config_name "default+hi_in" \
|
476 |
+
--dataset_split_name "test+test" \
|
477 |
+
--text_column_name "sentence+transcription" \
|
478 |
+
--batch_size 16 \
|
479 |
+
--dtype "bfloat16" \
|
480 |
+
--generation_max_length 256 \
|
481 |
+
--language "hi" \
|
482 |
+
--attn_implementation "sdpa" \
|
483 |
+
--streaming
|
484 |
+
|
485 |
+
```
|
486 |
+
|
487 |
+
We see that we achieve a WER of TODO%, the same as what we obtained with the large-v3 model, but with an RTFx of TODO,
|
488 |
+
a factor of TODO faster than using the large-v3 model alone. The RTFx value can be improved by training the student on
|
489 |
+
more data and for more training steps, since this will improve the number of predicted tokens that match the teacher
|
490 |
+
predictions.
|
491 |
+
|
492 |
+
## Recommendations and guidelines
|
493 |
+
|
494 |
+
### 1. Datasets
|
495 |
+
|
496 |
+
As explained, ideally, you should aim for ~1000 hours of audio data for training a distilled model via KD. Moreover, you should evaluate your model on out-of-distribution test sets to assess generalization capacities. With at least 1500 hours of audio data for German, Dutch, French and Spanish, 600 hours for Italian, and 300 hours for Portuguese and Polish (which can be supplemented with your own datasets), a good setup to start with is:
|
497 |
+
- **Training datasets:** [Common Voice 17](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0) and [Multilingual Librispeech](https://huggingface.co/datasets/facebook/multilingual_librispeech). Use the `train` split for training, and the `validation` and `test` splits for in-distribution testing.
|
498 |
+
- **Test datasets:** [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) and [Fleurs](https://huggingface.co/datasets/google/fleurs). Use the `validation` and `test` splits for out-of-distribution testing.
|
499 |
+
|
500 |
+
### 2. Student model's decoder
|
501 |
+
#### 2.1 Number of Decoder Layers
|
502 |
+
|
503 |
+
We recommend using a 2-layers decoder (see language transfer below). However, you can adjust the number of decoder layers when initializing the student model to balance between inference speed and accuracy. Experimentation has revealed that the Pareto optimal points are with 2, 3, and 4-layers decoders. For indicative results, after 10,000 training steps and inference on an 80GB Nvidia H100 with a batch size of 16 and 20 tokens generation, compared to [Whiper *large-v3*](https://huggingface.co/openai/whisper-large-v3) baseline:
|
504 |
+
|
505 |
+
<center>
|
506 |
+
|
507 |
+
| | rel. token gen. speed | ΔWER(%) |
|
508 |
+
|----------|:-------------:|------:|
|
509 |
+
| 2 layers | $3.66$ | $-3.5$ |
|
510 |
+
| 3 layers | $3.35$ | $-2.3$ |
|
511 |
+
| 4 layers | $3.11$ | $-1.8$ |
|
512 |
+
|
513 |
+
</center>
|
514 |
+
|
515 |
+
|
516 |
+
#### 2.2 Language Transfer
|
517 |
+
|
518 |
+
If you opt for a 2-layers decoder, consider leveraging language transfer by initializing the student model from the [distil-large-v3 English distilled model](https://huggingface.co/distil-whisper/distil-large-v3). For French, this method has shown performance improvements of ΔWER=-1.9% (compared to a 2-layers decoder initialized from [Whiper *large-v3*](https://huggingface.co/openai/whisper-large-v3)) after 10,000 training steps.
|
519 |
+
|
520 |
+
```diff
|
521 |
+
- --teacher_checkpoint "openai/whisper-large-v3" \
|
522 |
+
+ --teacher_checkpoint "distil-whisper/distil-large-v3" \
|
523 |
+
```
|
524 |
+
|
525 |
+
### 3. Language mixing
|
526 |
+
|
527 |
+
If you're working with low-resource languages (<500 hours of audio data), consider mixing your training data with a closely related language (for example, mix French and Spanish) to leverage knowledge transfer between languages. Experiments showed that mixing ~400 hours of French (which resulted in a model with poor generalization capacities) with ~500 hours of Spanish improved the model's out-of-distribution performance on French by ΔWER=-7.5%.
|
528 |
+
|
529 |
+
To do this:
|
530 |
+
1. Run [pseudo labeling](#1-pseudo-labelling) for each training dataset, setting the `--language` flag to the language of the respective dataset. In the example of mixing French and Spanish, simply modify the given [pseudo labeling](#1-pseudo-labelling) command with:
|
531 |
+
* pseudo labelling the French dataset
|
532 |
+
```diff
|
533 |
+
- --dataset_config_name "hi" \
|
534 |
+
- --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
535 |
+
- --language "hi" \
|
536 |
+
+ --dataset_config_name "fr" \
|
537 |
+
+ --output_dir "./common_voice_16_1_fr_pseudo_labelled" \
|
538 |
+
+ --language "fr" \
|
539 |
+
```
|
540 |
+
* pseudo labelling the Spanish dataset
|
541 |
+
```diff
|
542 |
+
- --dataset_config_name "hi" \
|
543 |
+
- --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
544 |
+
- --language "hi" \
|
545 |
+
+ --dataset_config_name "es" \
|
546 |
+
+ --output_dir "./common_voice_16_1_es_pseudo_labelled" \
|
547 |
+
+ --language "es" \
|
548 |
+
```
|
549 |
+
|
550 |
+
2. Conduct [training](#3-training) on these pseudo-labeled datasets, using the `--language` flag set to your targeted language. Note that this flag is only used for evaluation purposes, so you set it to the targeted language. The language token used for forwarding the teacher and student model decoders is the one used and saved in pseudo labels during pseudo-labeling, ensuring it's the correct one for the considered sample. In the example of mixing French and Spanish, simply modify the given [training](#1-pseudo-labelling) command with:
|
551 |
+
```diff
|
552 |
+
- --train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
|
553 |
+
- --train_split_name "train+validation" \
|
554 |
+
- --eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
|
555 |
+
- --eval_split_name "test" \
|
556 |
+
+ --train_dataset_name "../common_voice_17_0_fr_pseudo_labelled+../common_voice_17_0_es_pseudo_labelled" \
|
557 |
+
+ --train_split_name "train+train" \
|
558 |
+
+ --eval_dataset_name "../common_voice_16_1_fr_pseudo_labelled" \
|
559 |
+
+ --eval_split_name "validation" \
|
560 |
+
```
|
561 |
+
|
562 |
+
## Overview of Training Methods
|
563 |
+
|
564 |
+
### 1. Fine-Tuning
|
565 |
+
|
566 |
+
For fine-tuning, we take the original Whisper checkpoint and train it on one or more datasets using the standard
|
567 |
+
cross-entropy loss. As such, there is no involvement from the teacher checkpoint during training, and so the fine-tuned
|
568 |
+
model is permitted to *overfit* to the distribution of the training data we provide. This makes it appealing for "low-resource"
|
569 |
+
languages where the original Whisper model performs poorly, since we can boost the performance of the model on a single
|
570 |
+
language by *overfitting* to that distribution of data. Note that this means the fine-tuned model is prone to loosing
|
571 |
+
its robustness to different audio distributions, which is the trade-off with improving performance on a specified dataset.
|
572 |
+
|
573 |
+
As a rule of thumb, fine-tuning is appropriate for languages where the original Whisper model performs > 20% WER, and we
|
574 |
+
have a relatively small quantity of training data available (< 1000 hours). With fine-tuning, we require as little as **10 hours**
|
575 |
+
of training data to significantly boost the performance of the Whisper model. For an in-depth guide to fine-tuning Whisper,
|
576 |
+
the reader is advised to refer to the blog post: [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).
|
577 |
+
|
578 |
+
### 2. Shrink and Fine-Tune
|
579 |
+
|
580 |
+
Shrink and fine-tune (SFT) is a knowledge distillation (KD) technique in which we first *shrink* the teacher model to a
|
581 |
+
smaller student model by copying maximally spaced layers, and then *fine-tune* the student model on the cross-entropy loss
|
582 |
+
as described above. Typically, we retain the full encoder from the Whisper model and only shrink the decoder. Retaining
|
583 |
+
the entire encoder helps significantly with maintaining Whisper's robustness to different audio distributions (_c.f._
|
584 |
+
Section 9.3 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
|
585 |
+
|
586 |
+
We can either train the student model on a dataset of (audio, text) pairs as above. Or, we can use the pre-trained
|
587 |
+
Whisper model to generate *pseudo-labels* for our audio data, and train on the (audio, pseudo-label) pairs.
|
588 |
+
|
589 |
+
Pseudo-labels can be used when either:
|
590 |
+
1. The original text transcriptions are normalised (lower-cased or no punctuation): the Whisper generated pseudo-labels contain both punctuation and casing, and so can be used as a substitute for the normalised transcriptions
|
591 |
+
2. The pre-trained Whisper model achieves < 20% WER on the languages: we then know the majority of the pseudo-labels will be accurate enough for us to train on.
|
592 |
+
|
593 |
+
They are not recommended when both of the following are true:
|
594 |
+
1. The original text is punctuated and cased
|
595 |
+
2. The pre-trained Whisper model achieves > 20% WER on the languages: in this case, we want to overfit to the particular distribution of the language, and so train directly on the original text data
|
596 |
+
|
597 |
+
To discard inaccurate pseudo-labels during training, we employ a simple WER heuristic to filter our pseudo-labelled
|
598 |
+
training data. We first normalise the original text and the pseudo-labelled text using the Whisper normaliser. If the
|
599 |
+
WER between the normalised text exceeds a 10% WER threshold, we discard the training sample. Else, we retain it for training.
|
600 |
+
Section 9.1 of the Distil-Whisper [paper](https://arxiv.org/abs/2311.00430) demonstrates the importance of using this
|
601 |
+
threshold for training.
|
602 |
+
|
603 |
+
### 3. KL Divergence
|
604 |
+
|
605 |
+
In the KL Divergence setting, the student model is initialised by shrinking the teacher as before, and then trained to
|
606 |
+
match the predictions of the teacher during training.
|
607 |
+
|
608 |
+
### Summary of Methods
|
609 |
+
|
610 |
+
The following table summarises the two training paradigms: fine-tuning and knowledge distillation (KD). It suggests
|
611 |
+
minimum values for the pre-trained WER / training data to achieve reasonable performance:
|
612 |
+
|
613 |
+
| Method | Pre-Trained WER / % | Training Data / h |
|
614 |
+
|-------------|---------------------|-------------------|
|
615 |
+
| Fine-tuning | > 20 | < 1000 |
|
616 |
+
| KD | < 20 | > 1000 |
|
617 |
+
|
618 |
+
## Acknowledgements
|
619 |
+
|
620 |
+
* OpenAI for the Whisper [model](https://huggingface.co/openai/whisper-large-v3) and [original codebase](https://github.com/openai/whisper)
|
621 |
+
* Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the Whisper model implementation
|
622 |
+
* Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program for Cloud TPU v4s used to train the official Distil-Whisper models
|
623 |
+
* The Hugging Face 🤗 cluster for enabling experimentation with the PyTorch scripts
|
624 |
+
|
625 |
+
## Citation
|
626 |
+
|
627 |
+
If you use this code-base, please consider citing the Distil-Whisper paper:
|
628 |
+
|
629 |
+
```
|
630 |
+
@misc{gandhi2023distilwhisper,
|
631 |
+
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
|
632 |
+
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
|
633 |
+
year={2023},
|
634 |
+
eprint={2311.00430},
|
635 |
+
archivePrefix={arXiv},
|
636 |
+
primaryClass={cs.CL}
|
637 |
+
}
|
638 |
+
```
|
added_tokens.json
ADDED
@@ -0,0 +1,1611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|0.00|>": 50365,
|
3 |
+
"<|0.02|>": 50366,
|
4 |
+
"<|0.04|>": 50367,
|
5 |
+
"<|0.06|>": 50368,
|
6 |
+
"<|0.08|>": 50369,
|
7 |
+
"<|0.10|>": 50370,
|
8 |
+
"<|0.12|>": 50371,
|
9 |
+
"<|0.14|>": 50372,
|
10 |
+
"<|0.16|>": 50373,
|
11 |
+
"<|0.18|>": 50374,
|
12 |
+
"<|0.20|>": 50375,
|
13 |
+
"<|0.22|>": 50376,
|
14 |
+
"<|0.24|>": 50377,
|
15 |
+
"<|0.26|>": 50378,
|
16 |
+
"<|0.28|>": 50379,
|
17 |
+
"<|0.30|>": 50380,
|
18 |
+
"<|0.32|>": 50381,
|
19 |
+
"<|0.34|>": 50382,
|
20 |
+
"<|0.36|>": 50383,
|
21 |
+
"<|0.38|>": 50384,
|
22 |
+
"<|0.40|>": 50385,
|
23 |
+
"<|0.42|>": 50386,
|
24 |
+
"<|0.44|>": 50387,
|
25 |
+
"<|0.46|>": 50388,
|
26 |
+
"<|0.48|>": 50389,
|
27 |
+
"<|0.50|>": 50390,
|
28 |
+
"<|0.52|>": 50391,
|
29 |
+
"<|0.54|>": 50392,
|
30 |
+
"<|0.56|>": 50393,
|
31 |
+
"<|0.58|>": 50394,
|
32 |
+
"<|0.60|>": 50395,
|
33 |
+
"<|0.62|>": 50396,
|
34 |
+
"<|0.64|>": 50397,
|
35 |
+
"<|0.66|>": 50398,
|
36 |
+
"<|0.68|>": 50399,
|
37 |
+
"<|0.70|>": 50400,
|
38 |
+
"<|0.72|>": 50401,
|
39 |
+
"<|0.74|>": 50402,
|
40 |
+
"<|0.76|>": 50403,
|
41 |
+
"<|0.78|>": 50404,
|
42 |
+
"<|0.80|>": 50405,
|
43 |
+
"<|0.82|>": 50406,
|
44 |
+
"<|0.84|>": 50407,
|
45 |
+
"<|0.86|>": 50408,
|
46 |
+
"<|0.88|>": 50409,
|
47 |
+
"<|0.90|>": 50410,
|
48 |
+
"<|0.92|>": 50411,
|
49 |
+
"<|0.94|>": 50412,
|
50 |
+
"<|0.96|>": 50413,
|
51 |
+
"<|0.98|>": 50414,
|
52 |
+
"<|1.00|>": 50415,
|
53 |
+
"<|1.02|>": 50416,
|
54 |
+
"<|1.04|>": 50417,
|
55 |
+
"<|1.06|>": 50418,
|
56 |
+
"<|1.08|>": 50419,
|
57 |
+
"<|1.10|>": 50420,
|
58 |
+
"<|1.12|>": 50421,
|
59 |
+
"<|1.14|>": 50422,
|
60 |
+
"<|1.16|>": 50423,
|
61 |
+
"<|1.18|>": 50424,
|
62 |
+
"<|1.20|>": 50425,
|
63 |
+
"<|1.22|>": 50426,
|
64 |
+
"<|1.24|>": 50427,
|
65 |
+
"<|1.26|>": 50428,
|
66 |
+
"<|1.28|>": 50429,
|
67 |
+
"<|1.30|>": 50430,
|
68 |
+
"<|1.32|>": 50431,
|
69 |
+
"<|1.34|>": 50432,
|
70 |
+
"<|1.36|>": 50433,
|
71 |
+
"<|1.38|>": 50434,
|
72 |
+
"<|1.40|>": 50435,
|
73 |
+
"<|1.42|>": 50436,
|
74 |
+
"<|1.44|>": 50437,
|
75 |
+
"<|1.46|>": 50438,
|
76 |
+
"<|1.48|>": 50439,
|
77 |
+
"<|1.50|>": 50440,
|
78 |
+
"<|1.52|>": 50441,
|
79 |
+
"<|1.54|>": 50442,
|
80 |
+
"<|1.56|>": 50443,
|
81 |
+
"<|1.58|>": 50444,
|
82 |
+
"<|1.60|>": 50445,
|
83 |
+
"<|1.62|>": 50446,
|
84 |
+
"<|1.64|>": 50447,
|
85 |
+
"<|1.66|>": 50448,
|
86 |
+
"<|1.68|>": 50449,
|
87 |
+
"<|1.70|>": 50450,
|
88 |
+
"<|1.72|>": 50451,
|
89 |
+
"<|1.74|>": 50452,
|
90 |
+
"<|1.76|>": 50453,
|
91 |
+
"<|1.78|>": 50454,
|
92 |
+
"<|1.80|>": 50455,
|
93 |
+
"<|1.82|>": 50456,
|
94 |
+
"<|1.84|>": 50457,
|
95 |
+
"<|1.86|>": 50458,
|
96 |
+
"<|1.88|>": 50459,
|
97 |
+
"<|1.90|>": 50460,
|
98 |
+
"<|1.92|>": 50461,
|
99 |
+
"<|1.94|>": 50462,
|
100 |
+
"<|1.96|>": 50463,
|
101 |
+
"<|1.98|>": 50464,
|
102 |
+
"<|10.00|>": 50865,
|
103 |
+
"<|10.02|>": 50866,
|
104 |
+
"<|10.04|>": 50867,
|
105 |
+
"<|10.06|>": 50868,
|
106 |
+
"<|10.08|>": 50869,
|
107 |
+
"<|10.10|>": 50870,
|
108 |
+
"<|10.12|>": 50871,
|
109 |
+
"<|10.14|>": 50872,
|
110 |
+
"<|10.16|>": 50873,
|
111 |
+
"<|10.18|>": 50874,
|
112 |
+
"<|10.20|>": 50875,
|
113 |
+
"<|10.22|>": 50876,
|
114 |
+
"<|10.24|>": 50877,
|
115 |
+
"<|10.26|>": 50878,
|
116 |
+
"<|10.28|>": 50879,
|
117 |
+
"<|10.30|>": 50880,
|
118 |
+
"<|10.32|>": 50881,
|
119 |
+
"<|10.34|>": 50882,
|
120 |
+
"<|10.36|>": 50883,
|
121 |
+
"<|10.38|>": 50884,
|
122 |
+
"<|10.40|>": 50885,
|
123 |
+
"<|10.42|>": 50886,
|
124 |
+
"<|10.44|>": 50887,
|
125 |
+
"<|10.46|>": 50888,
|
126 |
+
"<|10.48|>": 50889,
|
127 |
+
"<|10.50|>": 50890,
|
128 |
+
"<|10.52|>": 50891,
|
129 |
+
"<|10.54|>": 50892,
|
130 |
+
"<|10.56|>": 50893,
|
131 |
+
"<|10.58|>": 50894,
|
132 |
+
"<|10.60|>": 50895,
|
133 |
+
"<|10.62|>": 50896,
|
134 |
+
"<|10.64|>": 50897,
|
135 |
+
"<|10.66|>": 50898,
|
136 |
+
"<|10.68|>": 50899,
|
137 |
+
"<|10.70|>": 50900,
|
138 |
+
"<|10.72|>": 50901,
|
139 |
+
"<|10.74|>": 50902,
|
140 |
+
"<|10.76|>": 50903,
|
141 |
+
"<|10.78|>": 50904,
|
142 |
+
"<|10.80|>": 50905,
|
143 |
+
"<|10.82|>": 50906,
|
144 |
+
"<|10.84|>": 50907,
|
145 |
+
"<|10.86|>": 50908,
|
146 |
+
"<|10.88|>": 50909,
|
147 |
+
"<|10.90|>": 50910,
|
148 |
+
"<|10.92|>": 50911,
|
149 |
+
"<|10.94|>": 50912,
|
150 |
+
"<|10.96|>": 50913,
|
151 |
+
"<|10.98|>": 50914,
|
152 |
+
"<|11.00|>": 50915,
|
153 |
+
"<|11.02|>": 50916,
|
154 |
+
"<|11.04|>": 50917,
|
155 |
+
"<|11.06|>": 50918,
|
156 |
+
"<|11.08|>": 50919,
|
157 |
+
"<|11.10|>": 50920,
|
158 |
+
"<|11.12|>": 50921,
|
159 |
+
"<|11.14|>": 50922,
|
160 |
+
"<|11.16|>": 50923,
|
161 |
+
"<|11.18|>": 50924,
|
162 |
+
"<|11.20|>": 50925,
|
163 |
+
"<|11.22|>": 50926,
|
164 |
+
"<|11.24|>": 50927,
|
165 |
+
"<|11.26|>": 50928,
|
166 |
+
"<|11.28|>": 50929,
|
167 |
+
"<|11.30|>": 50930,
|
168 |
+
"<|11.32|>": 50931,
|
169 |
+
"<|11.34|>": 50932,
|
170 |
+
"<|11.36|>": 50933,
|
171 |
+
"<|11.38|>": 50934,
|
172 |
+
"<|11.40|>": 50935,
|
173 |
+
"<|11.42|>": 50936,
|
174 |
+
"<|11.44|>": 50937,
|
175 |
+
"<|11.46|>": 50938,
|
176 |
+
"<|11.48|>": 50939,
|
177 |
+
"<|11.50|>": 50940,
|
178 |
+
"<|11.52|>": 50941,
|
179 |
+
"<|11.54|>": 50942,
|
180 |
+
"<|11.56|>": 50943,
|
181 |
+
"<|11.58|>": 50944,
|
182 |
+
"<|11.60|>": 50945,
|
183 |
+
"<|11.62|>": 50946,
|
184 |
+
"<|11.64|>": 50947,
|
185 |
+
"<|11.66|>": 50948,
|
186 |
+
"<|11.68|>": 50949,
|
187 |
+
"<|11.70|>": 50950,
|
188 |
+
"<|11.72|>": 50951,
|
189 |
+
"<|11.74|>": 50952,
|
190 |
+
"<|11.76|>": 50953,
|
191 |
+
"<|11.78|>": 50954,
|
192 |
+
"<|11.80|>": 50955,
|
193 |
+
"<|11.82|>": 50956,
|
194 |
+
"<|11.84|>": 50957,
|
195 |
+
"<|11.86|>": 50958,
|
196 |
+
"<|11.88|>": 50959,
|
197 |
+
"<|11.90|>": 50960,
|
198 |
+
"<|11.92|>": 50961,
|
199 |
+
"<|11.94|>": 50962,
|
200 |
+
"<|11.96|>": 50963,
|
201 |
+
"<|11.98|>": 50964,
|
202 |
+
"<|12.00|>": 50965,
|
203 |
+
"<|12.02|>": 50966,
|
204 |
+
"<|12.04|>": 50967,
|
205 |
+
"<|12.06|>": 50968,
|
206 |
+
"<|12.08|>": 50969,
|
207 |
+
"<|12.10|>": 50970,
|
208 |
+
"<|12.12|>": 50971,
|
209 |
+
"<|12.14|>": 50972,
|
210 |
+
"<|12.16|>": 50973,
|
211 |
+
"<|12.18|>": 50974,
|
212 |
+
"<|12.20|>": 50975,
|
213 |
+
"<|12.22|>": 50976,
|
214 |
+
"<|12.24|>": 50977,
|
215 |
+
"<|12.26|>": 50978,
|
216 |
+
"<|12.28|>": 50979,
|
217 |
+
"<|12.30|>": 50980,
|
218 |
+
"<|12.32|>": 50981,
|
219 |
+
"<|12.34|>": 50982,
|
220 |
+
"<|12.36|>": 50983,
|
221 |
+
"<|12.38|>": 50984,
|
222 |
+
"<|12.40|>": 50985,
|
223 |
+
"<|12.42|>": 50986,
|
224 |
+
"<|12.44|>": 50987,
|
225 |
+
"<|12.46|>": 50988,
|
226 |
+
"<|12.48|>": 50989,
|
227 |
+
"<|12.50|>": 50990,
|
228 |
+
"<|12.52|>": 50991,
|
229 |
+
"<|12.54|>": 50992,
|
230 |
+
"<|12.56|>": 50993,
|
231 |
+
"<|12.58|>": 50994,
|
232 |
+
"<|12.60|>": 50995,
|
233 |
+
"<|12.62|>": 50996,
|
234 |
+
"<|12.64|>": 50997,
|
235 |
+
"<|12.66|>": 50998,
|
236 |
+
"<|12.68|>": 50999,
|
237 |
+
"<|12.70|>": 51000,
|
238 |
+
"<|12.72|>": 51001,
|
239 |
+
"<|12.74|>": 51002,
|
240 |
+
"<|12.76|>": 51003,
|
241 |
+
"<|12.78|>": 51004,
|
242 |
+
"<|12.80|>": 51005,
|
243 |
+
"<|12.82|>": 51006,
|
244 |
+
"<|12.84|>": 51007,
|
245 |
+
"<|12.86|>": 51008,
|
246 |
+
"<|12.88|>": 51009,
|
247 |
+
"<|12.90|>": 51010,
|
248 |
+
"<|12.92|>": 51011,
|
249 |
+
"<|12.94|>": 51012,
|
250 |
+
"<|12.96|>": 51013,
|
251 |
+
"<|12.98|>": 51014,
|
252 |
+
"<|13.00|>": 51015,
|
253 |
+
"<|13.02|>": 51016,
|
254 |
+
"<|13.04|>": 51017,
|
255 |
+
"<|13.06|>": 51018,
|
256 |
+
"<|13.08|>": 51019,
|
257 |
+
"<|13.10|>": 51020,
|
258 |
+
"<|13.12|>": 51021,
|
259 |
+
"<|13.14|>": 51022,
|
260 |
+
"<|13.16|>": 51023,
|
261 |
+
"<|13.18|>": 51024,
|
262 |
+
"<|13.20|>": 51025,
|
263 |
+
"<|13.22|>": 51026,
|
264 |
+
"<|13.24|>": 51027,
|
265 |
+
"<|13.26|>": 51028,
|
266 |
+
"<|13.28|>": 51029,
|
267 |
+
"<|13.30|>": 51030,
|
268 |
+
"<|13.32|>": 51031,
|
269 |
+
"<|13.34|>": 51032,
|
270 |
+
"<|13.36|>": 51033,
|
271 |
+
"<|13.38|>": 51034,
|
272 |
+
"<|13.40|>": 51035,
|
273 |
+
"<|13.42|>": 51036,
|
274 |
+
"<|13.44|>": 51037,
|
275 |
+
"<|13.46|>": 51038,
|
276 |
+
"<|13.48|>": 51039,
|
277 |
+
"<|13.50|>": 51040,
|
278 |
+
"<|13.52|>": 51041,
|
279 |
+
"<|13.54|>": 51042,
|
280 |
+
"<|13.56|>": 51043,
|
281 |
+
"<|13.58|>": 51044,
|
282 |
+
"<|13.60|>": 51045,
|
283 |
+
"<|13.62|>": 51046,
|
284 |
+
"<|13.64|>": 51047,
|
285 |
+
"<|13.66|>": 51048,
|
286 |
+
"<|13.68|>": 51049,
|
287 |
+
"<|13.70|>": 51050,
|
288 |
+
"<|13.72|>": 51051,
|
289 |
+
"<|13.74|>": 51052,
|
290 |
+
"<|13.76|>": 51053,
|
291 |
+
"<|13.78|>": 51054,
|
292 |
+
"<|13.80|>": 51055,
|
293 |
+
"<|13.82|>": 51056,
|
294 |
+
"<|13.84|>": 51057,
|
295 |
+
"<|13.86|>": 51058,
|
296 |
+
"<|13.88|>": 51059,
|
297 |
+
"<|13.90|>": 51060,
|
298 |
+
"<|13.92|>": 51061,
|
299 |
+
"<|13.94|>": 51062,
|
300 |
+
"<|13.96|>": 51063,
|
301 |
+
"<|13.98|>": 51064,
|
302 |
+
"<|14.00|>": 51065,
|
303 |
+
"<|14.02|>": 51066,
|
304 |
+
"<|14.04|>": 51067,
|
305 |
+
"<|14.06|>": 51068,
|
306 |
+
"<|14.08|>": 51069,
|
307 |
+
"<|14.10|>": 51070,
|
308 |
+
"<|14.12|>": 51071,
|
309 |
+
"<|14.14|>": 51072,
|
310 |
+
"<|14.16|>": 51073,
|
311 |
+
"<|14.18|>": 51074,
|
312 |
+
"<|14.20|>": 51075,
|
313 |
+
"<|14.22|>": 51076,
|
314 |
+
"<|14.24|>": 51077,
|
315 |
+
"<|14.26|>": 51078,
|
316 |
+
"<|14.28|>": 51079,
|
317 |
+
"<|14.30|>": 51080,
|
318 |
+
"<|14.32|>": 51081,
|
319 |
+
"<|14.34|>": 51082,
|
320 |
+
"<|14.36|>": 51083,
|
321 |
+
"<|14.38|>": 51084,
|
322 |
+
"<|14.40|>": 51085,
|
323 |
+
"<|14.42|>": 51086,
|
324 |
+
"<|14.44|>": 51087,
|
325 |
+
"<|14.46|>": 51088,
|
326 |
+
"<|14.48|>": 51089,
|
327 |
+
"<|14.50|>": 51090,
|
328 |
+
"<|14.52|>": 51091,
|
329 |
+
"<|14.54|>": 51092,
|
330 |
+
"<|14.56|>": 51093,
|
331 |
+
"<|14.58|>": 51094,
|
332 |
+
"<|14.60|>": 51095,
|
333 |
+
"<|14.62|>": 51096,
|
334 |
+
"<|14.64|>": 51097,
|
335 |
+
"<|14.66|>": 51098,
|
336 |
+
"<|14.68|>": 51099,
|
337 |
+
"<|14.70|>": 51100,
|
338 |
+
"<|14.72|>": 51101,
|
339 |
+
"<|14.74|>": 51102,
|
340 |
+
"<|14.76|>": 51103,
|
341 |
+
"<|14.78|>": 51104,
|
342 |
+
"<|14.80|>": 51105,
|
343 |
+
"<|14.82|>": 51106,
|
344 |
+
"<|14.84|>": 51107,
|
345 |
+
"<|14.86|>": 51108,
|
346 |
+
"<|14.88|>": 51109,
|
347 |
+
"<|14.90|>": 51110,
|
348 |
+
"<|14.92|>": 51111,
|
349 |
+
"<|14.94|>": 51112,
|
350 |
+
"<|14.96|>": 51113,
|
351 |
+
"<|14.98|>": 51114,
|
352 |
+
"<|15.00|>": 51115,
|
353 |
+
"<|15.02|>": 51116,
|
354 |
+
"<|15.04|>": 51117,
|
355 |
+
"<|15.06|>": 51118,
|
356 |
+
"<|15.08|>": 51119,
|
357 |
+
"<|15.10|>": 51120,
|
358 |
+
"<|15.12|>": 51121,
|
359 |
+
"<|15.14|>": 51122,
|
360 |
+
"<|15.16|>": 51123,
|
361 |
+
"<|15.18|>": 51124,
|
362 |
+
"<|15.20|>": 51125,
|
363 |
+
"<|15.22|>": 51126,
|
364 |
+
"<|15.24|>": 51127,
|
365 |
+
"<|15.26|>": 51128,
|
366 |
+
"<|15.28|>": 51129,
|
367 |
+
"<|15.30|>": 51130,
|
368 |
+
"<|15.32|>": 51131,
|
369 |
+
"<|15.34|>": 51132,
|
370 |
+
"<|15.36|>": 51133,
|
371 |
+
"<|15.38|>": 51134,
|
372 |
+
"<|15.40|>": 51135,
|
373 |
+
"<|15.42|>": 51136,
|
374 |
+
"<|15.44|>": 51137,
|
375 |
+
"<|15.46|>": 51138,
|
376 |
+
"<|15.48|>": 51139,
|
377 |
+
"<|15.50|>": 51140,
|
378 |
+
"<|15.52|>": 51141,
|
379 |
+
"<|15.54|>": 51142,
|
380 |
+
"<|15.56|>": 51143,
|
381 |
+
"<|15.58|>": 51144,
|
382 |
+
"<|15.60|>": 51145,
|
383 |
+
"<|15.62|>": 51146,
|
384 |
+
"<|15.64|>": 51147,
|
385 |
+
"<|15.66|>": 51148,
|
386 |
+
"<|15.68|>": 51149,
|
387 |
+
"<|15.70|>": 51150,
|
388 |
+
"<|15.72|>": 51151,
|
389 |
+
"<|15.74|>": 51152,
|
390 |
+
"<|15.76|>": 51153,
|
391 |
+
"<|15.78|>": 51154,
|
392 |
+
"<|15.80|>": 51155,
|
393 |
+
"<|15.82|>": 51156,
|
394 |
+
"<|15.84|>": 51157,
|
395 |
+
"<|15.86|>": 51158,
|
396 |
+
"<|15.88|>": 51159,
|
397 |
+
"<|15.90|>": 51160,
|
398 |
+
"<|15.92|>": 51161,
|
399 |
+
"<|15.94|>": 51162,
|
400 |
+
"<|15.96|>": 51163,
|
401 |
+
"<|15.98|>": 51164,
|
402 |
+
"<|16.00|>": 51165,
|
403 |
+
"<|16.02|>": 51166,
|
404 |
+
"<|16.04|>": 51167,
|
405 |
+
"<|16.06|>": 51168,
|
406 |
+
"<|16.08|>": 51169,
|
407 |
+
"<|16.10|>": 51170,
|
408 |
+
"<|16.12|>": 51171,
|
409 |
+
"<|16.14|>": 51172,
|
410 |
+
"<|16.16|>": 51173,
|
411 |
+
"<|16.18|>": 51174,
|
412 |
+
"<|16.20|>": 51175,
|
413 |
+
"<|16.22|>": 51176,
|
414 |
+
"<|16.24|>": 51177,
|
415 |
+
"<|16.26|>": 51178,
|
416 |
+
"<|16.28|>": 51179,
|
417 |
+
"<|16.30|>": 51180,
|
418 |
+
"<|16.32|>": 51181,
|
419 |
+
"<|16.34|>": 51182,
|
420 |
+
"<|16.36|>": 51183,
|
421 |
+
"<|16.38|>": 51184,
|
422 |
+
"<|16.40|>": 51185,
|
423 |
+
"<|16.42|>": 51186,
|
424 |
+
"<|16.44|>": 51187,
|
425 |
+
"<|16.46|>": 51188,
|
426 |
+
"<|16.48|>": 51189,
|
427 |
+
"<|16.50|>": 51190,
|
428 |
+
"<|16.52|>": 51191,
|
429 |
+
"<|16.54|>": 51192,
|
430 |
+
"<|16.56|>": 51193,
|
431 |
+
"<|16.58|>": 51194,
|
432 |
+
"<|16.60|>": 51195,
|
433 |
+
"<|16.62|>": 51196,
|
434 |
+
"<|16.64|>": 51197,
|
435 |
+
"<|16.66|>": 51198,
|
436 |
+
"<|16.68|>": 51199,
|
437 |
+
"<|16.70|>": 51200,
|
438 |
+
"<|16.72|>": 51201,
|
439 |
+
"<|16.74|>": 51202,
|
440 |
+
"<|16.76|>": 51203,
|
441 |
+
"<|16.78|>": 51204,
|
442 |
+
"<|16.80|>": 51205,
|
443 |
+
"<|16.82|>": 51206,
|
444 |
+
"<|16.84|>": 51207,
|
445 |
+
"<|16.86|>": 51208,
|
446 |
+
"<|16.88|>": 51209,
|
447 |
+
"<|16.90|>": 51210,
|
448 |
+
"<|16.92|>": 51211,
|
449 |
+
"<|16.94|>": 51212,
|
450 |
+
"<|16.96|>": 51213,
|
451 |
+
"<|16.98|>": 51214,
|
452 |
+
"<|17.00|>": 51215,
|
453 |
+
"<|17.02|>": 51216,
|
454 |
+
"<|17.04|>": 51217,
|
455 |
+
"<|17.06|>": 51218,
|
456 |
+
"<|17.08|>": 51219,
|
457 |
+
"<|17.10|>": 51220,
|
458 |
+
"<|17.12|>": 51221,
|
459 |
+
"<|17.14|>": 51222,
|
460 |
+
"<|17.16|>": 51223,
|
461 |
+
"<|17.18|>": 51224,
|
462 |
+
"<|17.20|>": 51225,
|
463 |
+
"<|17.22|>": 51226,
|
464 |
+
"<|17.24|>": 51227,
|
465 |
+
"<|17.26|>": 51228,
|
466 |
+
"<|17.28|>": 51229,
|
467 |
+
"<|17.30|>": 51230,
|
468 |
+
"<|17.32|>": 51231,
|
469 |
+
"<|17.34|>": 51232,
|
470 |
+
"<|17.36|>": 51233,
|
471 |
+
"<|17.38|>": 51234,
|
472 |
+
"<|17.40|>": 51235,
|
473 |
+
"<|17.42|>": 51236,
|
474 |
+
"<|17.44|>": 51237,
|
475 |
+
"<|17.46|>": 51238,
|
476 |
+
"<|17.48|>": 51239,
|
477 |
+
"<|17.50|>": 51240,
|
478 |
+
"<|17.52|>": 51241,
|
479 |
+
"<|17.54|>": 51242,
|
480 |
+
"<|17.56|>": 51243,
|
481 |
+
"<|17.58|>": 51244,
|
482 |
+
"<|17.60|>": 51245,
|
483 |
+
"<|17.62|>": 51246,
|
484 |
+
"<|17.64|>": 51247,
|
485 |
+
"<|17.66|>": 51248,
|
486 |
+
"<|17.68|>": 51249,
|
487 |
+
"<|17.70|>": 51250,
|
488 |
+
"<|17.72|>": 51251,
|
489 |
+
"<|17.74|>": 51252,
|
490 |
+
"<|17.76|>": 51253,
|
491 |
+
"<|17.78|>": 51254,
|
492 |
+
"<|17.80|>": 51255,
|
493 |
+
"<|17.82|>": 51256,
|
494 |
+
"<|17.84|>": 51257,
|
495 |
+
"<|17.86|>": 51258,
|
496 |
+
"<|17.88|>": 51259,
|
497 |
+
"<|17.90|>": 51260,
|
498 |
+
"<|17.92|>": 51261,
|
499 |
+
"<|17.94|>": 51262,
|
500 |
+
"<|17.96|>": 51263,
|
501 |
+
"<|17.98|>": 51264,
|
502 |
+
"<|18.00|>": 51265,
|
503 |
+
"<|18.02|>": 51266,
|
504 |
+
"<|18.04|>": 51267,
|
505 |
+
"<|18.06|>": 51268,
|
506 |
+
"<|18.08|>": 51269,
|
507 |
+
"<|18.10|>": 51270,
|
508 |
+
"<|18.12|>": 51271,
|
509 |
+
"<|18.14|>": 51272,
|
510 |
+
"<|18.16|>": 51273,
|
511 |
+
"<|18.18|>": 51274,
|
512 |
+
"<|18.20|>": 51275,
|
513 |
+
"<|18.22|>": 51276,
|
514 |
+
"<|18.24|>": 51277,
|
515 |
+
"<|18.26|>": 51278,
|
516 |
+
"<|18.28|>": 51279,
|
517 |
+
"<|18.30|>": 51280,
|
518 |
+
"<|18.32|>": 51281,
|
519 |
+
"<|18.34|>": 51282,
|
520 |
+
"<|18.36|>": 51283,
|
521 |
+
"<|18.38|>": 51284,
|
522 |
+
"<|18.40|>": 51285,
|
523 |
+
"<|18.42|>": 51286,
|
524 |
+
"<|18.44|>": 51287,
|
525 |
+
"<|18.46|>": 51288,
|
526 |
+
"<|18.48|>": 51289,
|
527 |
+
"<|18.50|>": 51290,
|
528 |
+
"<|18.52|>": 51291,
|
529 |
+
"<|18.54|>": 51292,
|
530 |
+
"<|18.56|>": 51293,
|
531 |
+
"<|18.58|>": 51294,
|
532 |
+
"<|18.60|>": 51295,
|
533 |
+
"<|18.62|>": 51296,
|
534 |
+
"<|18.64|>": 51297,
|
535 |
+
"<|18.66|>": 51298,
|
536 |
+
"<|18.68|>": 51299,
|
537 |
+
"<|18.70|>": 51300,
|
538 |
+
"<|18.72|>": 51301,
|
539 |
+
"<|18.74|>": 51302,
|
540 |
+
"<|18.76|>": 51303,
|
541 |
+
"<|18.78|>": 51304,
|
542 |
+
"<|18.80|>": 51305,
|
543 |
+
"<|18.82|>": 51306,
|
544 |
+
"<|18.84|>": 51307,
|
545 |
+
"<|18.86|>": 51308,
|
546 |
+
"<|18.88|>": 51309,
|
547 |
+
"<|18.90|>": 51310,
|
548 |
+
"<|18.92|>": 51311,
|
549 |
+
"<|18.94|>": 51312,
|
550 |
+
"<|18.96|>": 51313,
|
551 |
+
"<|18.98|>": 51314,
|
552 |
+
"<|19.00|>": 51315,
|
553 |
+
"<|19.02|>": 51316,
|
554 |
+
"<|19.04|>": 51317,
|
555 |
+
"<|19.06|>": 51318,
|
556 |
+
"<|19.08|>": 51319,
|
557 |
+
"<|19.10|>": 51320,
|
558 |
+
"<|19.12|>": 51321,
|
559 |
+
"<|19.14|>": 51322,
|
560 |
+
"<|19.16|>": 51323,
|
561 |
+
"<|19.18|>": 51324,
|
562 |
+
"<|19.20|>": 51325,
|
563 |
+
"<|19.22|>": 51326,
|
564 |
+
"<|19.24|>": 51327,
|
565 |
+
"<|19.26|>": 51328,
|
566 |
+
"<|19.28|>": 51329,
|
567 |
+
"<|19.30|>": 51330,
|
568 |
+
"<|19.32|>": 51331,
|
569 |
+
"<|19.34|>": 51332,
|
570 |
+
"<|19.36|>": 51333,
|
571 |
+
"<|19.38|>": 51334,
|
572 |
+
"<|19.40|>": 51335,
|
573 |
+
"<|19.42|>": 51336,
|
574 |
+
"<|19.44|>": 51337,
|
575 |
+
"<|19.46|>": 51338,
|
576 |
+
"<|19.48|>": 51339,
|
577 |
+
"<|19.50|>": 51340,
|
578 |
+
"<|19.52|>": 51341,
|
579 |
+
"<|19.54|>": 51342,
|
580 |
+
"<|19.56|>": 51343,
|
581 |
+
"<|19.58|>": 51344,
|
582 |
+
"<|19.60|>": 51345,
|
583 |
+
"<|19.62|>": 51346,
|
584 |
+
"<|19.64|>": 51347,
|
585 |
+
"<|19.66|>": 51348,
|
586 |
+
"<|19.68|>": 51349,
|
587 |
+
"<|19.70|>": 51350,
|
588 |
+
"<|19.72|>": 51351,
|
589 |
+
"<|19.74|>": 51352,
|
590 |
+
"<|19.76|>": 51353,
|
591 |
+
"<|19.78|>": 51354,
|
592 |
+
"<|19.80|>": 51355,
|
593 |
+
"<|19.82|>": 51356,
|
594 |
+
"<|19.84|>": 51357,
|
595 |
+
"<|19.86|>": 51358,
|
596 |
+
"<|19.88|>": 51359,
|
597 |
+
"<|19.90|>": 51360,
|
598 |
+
"<|19.92|>": 51361,
|
599 |
+
"<|19.94|>": 51362,
|
600 |
+
"<|19.96|>": 51363,
|
601 |
+
"<|19.98|>": 51364,
|
602 |
+
"<|2.00|>": 50465,
|
603 |
+
"<|2.02|>": 50466,
|
604 |
+
"<|2.04|>": 50467,
|
605 |
+
"<|2.06|>": 50468,
|
606 |
+
"<|2.08|>": 50469,
|
607 |
+
"<|2.10|>": 50470,
|
608 |
+
"<|2.12|>": 50471,
|
609 |
+
"<|2.14|>": 50472,
|
610 |
+
"<|2.16|>": 50473,
|
611 |
+
"<|2.18|>": 50474,
|
612 |
+
"<|2.20|>": 50475,
|
613 |
+
"<|2.22|>": 50476,
|
614 |
+
"<|2.24|>": 50477,
|
615 |
+
"<|2.26|>": 50478,
|
616 |
+
"<|2.28|>": 50479,
|
617 |
+
"<|2.30|>": 50480,
|
618 |
+
"<|2.32|>": 50481,
|
619 |
+
"<|2.34|>": 50482,
|
620 |
+
"<|2.36|>": 50483,
|
621 |
+
"<|2.38|>": 50484,
|
622 |
+
"<|2.40|>": 50485,
|
623 |
+
"<|2.42|>": 50486,
|
624 |
+
"<|2.44|>": 50487,
|
625 |
+
"<|2.46|>": 50488,
|
626 |
+
"<|2.48|>": 50489,
|
627 |
+
"<|2.50|>": 50490,
|
628 |
+
"<|2.52|>": 50491,
|
629 |
+
"<|2.54|>": 50492,
|
630 |
+
"<|2.56|>": 50493,
|
631 |
+
"<|2.58|>": 50494,
|
632 |
+
"<|2.60|>": 50495,
|
633 |
+
"<|2.62|>": 50496,
|
634 |
+
"<|2.64|>": 50497,
|
635 |
+
"<|2.66|>": 50498,
|
636 |
+
"<|2.68|>": 50499,
|
637 |
+
"<|2.70|>": 50500,
|
638 |
+
"<|2.72|>": 50501,
|
639 |
+
"<|2.74|>": 50502,
|
640 |
+
"<|2.76|>": 50503,
|
641 |
+
"<|2.78|>": 50504,
|
642 |
+
"<|2.80|>": 50505,
|
643 |
+
"<|2.82|>": 50506,
|
644 |
+
"<|2.84|>": 50507,
|
645 |
+
"<|2.86|>": 50508,
|
646 |
+
"<|2.88|>": 50509,
|
647 |
+
"<|2.90|>": 50510,
|
648 |
+
"<|2.92|>": 50511,
|
649 |
+
"<|2.94|>": 50512,
|
650 |
+
"<|2.96|>": 50513,
|
651 |
+
"<|2.98|>": 50514,
|
652 |
+
"<|20.00|>": 51365,
|
653 |
+
"<|20.02|>": 51366,
|
654 |
+
"<|20.04|>": 51367,
|
655 |
+
"<|20.06|>": 51368,
|
656 |
+
"<|20.08|>": 51369,
|
657 |
+
"<|20.10|>": 51370,
|
658 |
+
"<|20.12|>": 51371,
|
659 |
+
"<|20.14|>": 51372,
|
660 |
+
"<|20.16|>": 51373,
|
661 |
+
"<|20.18|>": 51374,
|
662 |
+
"<|20.20|>": 51375,
|
663 |
+
"<|20.22|>": 51376,
|
664 |
+
"<|20.24|>": 51377,
|
665 |
+
"<|20.26|>": 51378,
|
666 |
+
"<|20.28|>": 51379,
|
667 |
+
"<|20.30|>": 51380,
|
668 |
+
"<|20.32|>": 51381,
|
669 |
+
"<|20.34|>": 51382,
|
670 |
+
"<|20.36|>": 51383,
|
671 |
+
"<|20.38|>": 51384,
|
672 |
+
"<|20.40|>": 51385,
|
673 |
+
"<|20.42|>": 51386,
|
674 |
+
"<|20.44|>": 51387,
|
675 |
+
"<|20.46|>": 51388,
|
676 |
+
"<|20.48|>": 51389,
|
677 |
+
"<|20.50|>": 51390,
|
678 |
+
"<|20.52|>": 51391,
|
679 |
+
"<|20.54|>": 51392,
|
680 |
+
"<|20.56|>": 51393,
|
681 |
+
"<|20.58|>": 51394,
|
682 |
+
"<|20.60|>": 51395,
|
683 |
+
"<|20.62|>": 51396,
|
684 |
+
"<|20.64|>": 51397,
|
685 |
+
"<|20.66|>": 51398,
|
686 |
+
"<|20.68|>": 51399,
|
687 |
+
"<|20.70|>": 51400,
|
688 |
+
"<|20.72|>": 51401,
|
689 |
+
"<|20.74|>": 51402,
|
690 |
+
"<|20.76|>": 51403,
|
691 |
+
"<|20.78|>": 51404,
|
692 |
+
"<|20.80|>": 51405,
|
693 |
+
"<|20.82|>": 51406,
|
694 |
+
"<|20.84|>": 51407,
|
695 |
+
"<|20.86|>": 51408,
|
696 |
+
"<|20.88|>": 51409,
|
697 |
+
"<|20.90|>": 51410,
|
698 |
+
"<|20.92|>": 51411,
|
699 |
+
"<|20.94|>": 51412,
|
700 |
+
"<|20.96|>": 51413,
|
701 |
+
"<|20.98|>": 51414,
|
702 |
+
"<|21.00|>": 51415,
|
703 |
+
"<|21.02|>": 51416,
|
704 |
+
"<|21.04|>": 51417,
|
705 |
+
"<|21.06|>": 51418,
|
706 |
+
"<|21.08|>": 51419,
|
707 |
+
"<|21.10|>": 51420,
|
708 |
+
"<|21.12|>": 51421,
|
709 |
+
"<|21.14|>": 51422,
|
710 |
+
"<|21.16|>": 51423,
|
711 |
+
"<|21.18|>": 51424,
|
712 |
+
"<|21.20|>": 51425,
|
713 |
+
"<|21.22|>": 51426,
|
714 |
+
"<|21.24|>": 51427,
|
715 |
+
"<|21.26|>": 51428,
|
716 |
+
"<|21.28|>": 51429,
|
717 |
+
"<|21.30|>": 51430,
|
718 |
+
"<|21.32|>": 51431,
|
719 |
+
"<|21.34|>": 51432,
|
720 |
+
"<|21.36|>": 51433,
|
721 |
+
"<|21.38|>": 51434,
|
722 |
+
"<|21.40|>": 51435,
|
723 |
+
"<|21.42|>": 51436,
|
724 |
+
"<|21.44|>": 51437,
|
725 |
+
"<|21.46|>": 51438,
|
726 |
+
"<|21.48|>": 51439,
|
727 |
+
"<|21.50|>": 51440,
|
728 |
+
"<|21.52|>": 51441,
|
729 |
+
"<|21.54|>": 51442,
|
730 |
+
"<|21.56|>": 51443,
|
731 |
+
"<|21.58|>": 51444,
|
732 |
+
"<|21.60|>": 51445,
|
733 |
+
"<|21.62|>": 51446,
|
734 |
+
"<|21.64|>": 51447,
|
735 |
+
"<|21.66|>": 51448,
|
736 |
+
"<|21.68|>": 51449,
|
737 |
+
"<|21.70|>": 51450,
|
738 |
+
"<|21.72|>": 51451,
|
739 |
+
"<|21.74|>": 51452,
|
740 |
+
"<|21.76|>": 51453,
|
741 |
+
"<|21.78|>": 51454,
|
742 |
+
"<|21.80|>": 51455,
|
743 |
+
"<|21.82|>": 51456,
|
744 |
+
"<|21.84|>": 51457,
|
745 |
+
"<|21.86|>": 51458,
|
746 |
+
"<|21.88|>": 51459,
|
747 |
+
"<|21.90|>": 51460,
|
748 |
+
"<|21.92|>": 51461,
|
749 |
+
"<|21.94|>": 51462,
|
750 |
+
"<|21.96|>": 51463,
|
751 |
+
"<|21.98|>": 51464,
|
752 |
+
"<|22.00|>": 51465,
|
753 |
+
"<|22.02|>": 51466,
|
754 |
+
"<|22.04|>": 51467,
|
755 |
+
"<|22.06|>": 51468,
|
756 |
+
"<|22.08|>": 51469,
|
757 |
+
"<|22.10|>": 51470,
|
758 |
+
"<|22.12|>": 51471,
|
759 |
+
"<|22.14|>": 51472,
|
760 |
+
"<|22.16|>": 51473,
|
761 |
+
"<|22.18|>": 51474,
|
762 |
+
"<|22.20|>": 51475,
|
763 |
+
"<|22.22|>": 51476,
|
764 |
+
"<|22.24|>": 51477,
|
765 |
+
"<|22.26|>": 51478,
|
766 |
+
"<|22.28|>": 51479,
|
767 |
+
"<|22.30|>": 51480,
|
768 |
+
"<|22.32|>": 51481,
|
769 |
+
"<|22.34|>": 51482,
|
770 |
+
"<|22.36|>": 51483,
|
771 |
+
"<|22.38|>": 51484,
|
772 |
+
"<|22.40|>": 51485,
|
773 |
+
"<|22.42|>": 51486,
|
774 |
+
"<|22.44|>": 51487,
|
775 |
+
"<|22.46|>": 51488,
|
776 |
+
"<|22.48|>": 51489,
|
777 |
+
"<|22.50|>": 51490,
|
778 |
+
"<|22.52|>": 51491,
|
779 |
+
"<|22.54|>": 51492,
|
780 |
+
"<|22.56|>": 51493,
|
781 |
+
"<|22.58|>": 51494,
|
782 |
+
"<|22.60|>": 51495,
|
783 |
+
"<|22.62|>": 51496,
|
784 |
+
"<|22.64|>": 51497,
|
785 |
+
"<|22.66|>": 51498,
|
786 |
+
"<|22.68|>": 51499,
|
787 |
+
"<|22.70|>": 51500,
|
788 |
+
"<|22.72|>": 51501,
|
789 |
+
"<|22.74|>": 51502,
|
790 |
+
"<|22.76|>": 51503,
|
791 |
+
"<|22.78|>": 51504,
|
792 |
+
"<|22.80|>": 51505,
|
793 |
+
"<|22.82|>": 51506,
|
794 |
+
"<|22.84|>": 51507,
|
795 |
+
"<|22.86|>": 51508,
|
796 |
+
"<|22.88|>": 51509,
|
797 |
+
"<|22.90|>": 51510,
|
798 |
+
"<|22.92|>": 51511,
|
799 |
+
"<|22.94|>": 51512,
|
800 |
+
"<|22.96|>": 51513,
|
801 |
+
"<|22.98|>": 51514,
|
802 |
+
"<|23.00|>": 51515,
|
803 |
+
"<|23.02|>": 51516,
|
804 |
+
"<|23.04|>": 51517,
|
805 |
+
"<|23.06|>": 51518,
|
806 |
+
"<|23.08|>": 51519,
|
807 |
+
"<|23.10|>": 51520,
|
808 |
+
"<|23.12|>": 51521,
|
809 |
+
"<|23.14|>": 51522,
|
810 |
+
"<|23.16|>": 51523,
|
811 |
+
"<|23.18|>": 51524,
|
812 |
+
"<|23.20|>": 51525,
|
813 |
+
"<|23.22|>": 51526,
|
814 |
+
"<|23.24|>": 51527,
|
815 |
+
"<|23.26|>": 51528,
|
816 |
+
"<|23.28|>": 51529,
|
817 |
+
"<|23.30|>": 51530,
|
818 |
+
"<|23.32|>": 51531,
|
819 |
+
"<|23.34|>": 51532,
|
820 |
+
"<|23.36|>": 51533,
|
821 |
+
"<|23.38|>": 51534,
|
822 |
+
"<|23.40|>": 51535,
|
823 |
+
"<|23.42|>": 51536,
|
824 |
+
"<|23.44|>": 51537,
|
825 |
+
"<|23.46|>": 51538,
|
826 |
+
"<|23.48|>": 51539,
|
827 |
+
"<|23.50|>": 51540,
|
828 |
+
"<|23.52|>": 51541,
|
829 |
+
"<|23.54|>": 51542,
|
830 |
+
"<|23.56|>": 51543,
|
831 |
+
"<|23.58|>": 51544,
|
832 |
+
"<|23.60|>": 51545,
|
833 |
+
"<|23.62|>": 51546,
|
834 |
+
"<|23.64|>": 51547,
|
835 |
+
"<|23.66|>": 51548,
|
836 |
+
"<|23.68|>": 51549,
|
837 |
+
"<|23.70|>": 51550,
|
838 |
+
"<|23.72|>": 51551,
|
839 |
+
"<|23.74|>": 51552,
|
840 |
+
"<|23.76|>": 51553,
|
841 |
+
"<|23.78|>": 51554,
|
842 |
+
"<|23.80|>": 51555,
|
843 |
+
"<|23.82|>": 51556,
|
844 |
+
"<|23.84|>": 51557,
|
845 |
+
"<|23.86|>": 51558,
|
846 |
+
"<|23.88|>": 51559,
|
847 |
+
"<|23.90|>": 51560,
|
848 |
+
"<|23.92|>": 51561,
|
849 |
+
"<|23.94|>": 51562,
|
850 |
+
"<|23.96|>": 51563,
|
851 |
+
"<|23.98|>": 51564,
|
852 |
+
"<|24.00|>": 51565,
|
853 |
+
"<|24.02|>": 51566,
|
854 |
+
"<|24.04|>": 51567,
|
855 |
+
"<|24.06|>": 51568,
|
856 |
+
"<|24.08|>": 51569,
|
857 |
+
"<|24.10|>": 51570,
|
858 |
+
"<|24.12|>": 51571,
|
859 |
+
"<|24.14|>": 51572,
|
860 |
+
"<|24.16|>": 51573,
|
861 |
+
"<|24.18|>": 51574,
|
862 |
+
"<|24.20|>": 51575,
|
863 |
+
"<|24.22|>": 51576,
|
864 |
+
"<|24.24|>": 51577,
|
865 |
+
"<|24.26|>": 51578,
|
866 |
+
"<|24.28|>": 51579,
|
867 |
+
"<|24.30|>": 51580,
|
868 |
+
"<|24.32|>": 51581,
|
869 |
+
"<|24.34|>": 51582,
|
870 |
+
"<|24.36|>": 51583,
|
871 |
+
"<|24.38|>": 51584,
|
872 |
+
"<|24.40|>": 51585,
|
873 |
+
"<|24.42|>": 51586,
|
874 |
+
"<|24.44|>": 51587,
|
875 |
+
"<|24.46|>": 51588,
|
876 |
+
"<|24.48|>": 51589,
|
877 |
+
"<|24.50|>": 51590,
|
878 |
+
"<|24.52|>": 51591,
|
879 |
+
"<|24.54|>": 51592,
|
880 |
+
"<|24.56|>": 51593,
|
881 |
+
"<|24.58|>": 51594,
|
882 |
+
"<|24.60|>": 51595,
|
883 |
+
"<|24.62|>": 51596,
|
884 |
+
"<|24.64|>": 51597,
|
885 |
+
"<|24.66|>": 51598,
|
886 |
+
"<|24.68|>": 51599,
|
887 |
+
"<|24.70|>": 51600,
|
888 |
+
"<|24.72|>": 51601,
|
889 |
+
"<|24.74|>": 51602,
|
890 |
+
"<|24.76|>": 51603,
|
891 |
+
"<|24.78|>": 51604,
|
892 |
+
"<|24.80|>": 51605,
|
893 |
+
"<|24.82|>": 51606,
|
894 |
+
"<|24.84|>": 51607,
|
895 |
+
"<|24.86|>": 51608,
|
896 |
+
"<|24.88|>": 51609,
|
897 |
+
"<|24.90|>": 51610,
|
898 |
+
"<|24.92|>": 51611,
|
899 |
+
"<|24.94|>": 51612,
|
900 |
+
"<|24.96|>": 51613,
|
901 |
+
"<|24.98|>": 51614,
|
902 |
+
"<|25.00|>": 51615,
|
903 |
+
"<|25.02|>": 51616,
|
904 |
+
"<|25.04|>": 51617,
|
905 |
+
"<|25.06|>": 51618,
|
906 |
+
"<|25.08|>": 51619,
|
907 |
+
"<|25.10|>": 51620,
|
908 |
+
"<|25.12|>": 51621,
|
909 |
+
"<|25.14|>": 51622,
|
910 |
+
"<|25.16|>": 51623,
|
911 |
+
"<|25.18|>": 51624,
|
912 |
+
"<|25.20|>": 51625,
|
913 |
+
"<|25.22|>": 51626,
|
914 |
+
"<|25.24|>": 51627,
|
915 |
+
"<|25.26|>": 51628,
|
916 |
+
"<|25.28|>": 51629,
|
917 |
+
"<|25.30|>": 51630,
|
918 |
+
"<|25.32|>": 51631,
|
919 |
+
"<|25.34|>": 51632,
|
920 |
+
"<|25.36|>": 51633,
|
921 |
+
"<|25.38|>": 51634,
|
922 |
+
"<|25.40|>": 51635,
|
923 |
+
"<|25.42|>": 51636,
|
924 |
+
"<|25.44|>": 51637,
|
925 |
+
"<|25.46|>": 51638,
|
926 |
+
"<|25.48|>": 51639,
|
927 |
+
"<|25.50|>": 51640,
|
928 |
+
"<|25.52|>": 51641,
|
929 |
+
"<|25.54|>": 51642,
|
930 |
+
"<|25.56|>": 51643,
|
931 |
+
"<|25.58|>": 51644,
|
932 |
+
"<|25.60|>": 51645,
|
933 |
+
"<|25.62|>": 51646,
|
934 |
+
"<|25.64|>": 51647,
|
935 |
+
"<|25.66|>": 51648,
|
936 |
+
"<|25.68|>": 51649,
|
937 |
+
"<|25.70|>": 51650,
|
938 |
+
"<|25.72|>": 51651,
|
939 |
+
"<|25.74|>": 51652,
|
940 |
+
"<|25.76|>": 51653,
|
941 |
+
"<|25.78|>": 51654,
|
942 |
+
"<|25.80|>": 51655,
|
943 |
+
"<|25.82|>": 51656,
|
944 |
+
"<|25.84|>": 51657,
|
945 |
+
"<|25.86|>": 51658,
|
946 |
+
"<|25.88|>": 51659,
|
947 |
+
"<|25.90|>": 51660,
|
948 |
+
"<|25.92|>": 51661,
|
949 |
+
"<|25.94|>": 51662,
|
950 |
+
"<|25.96|>": 51663,
|
951 |
+
"<|25.98|>": 51664,
|
952 |
+
"<|26.00|>": 51665,
|
953 |
+
"<|26.02|>": 51666,
|
954 |
+
"<|26.04|>": 51667,
|
955 |
+
"<|26.06|>": 51668,
|
956 |
+
"<|26.08|>": 51669,
|
957 |
+
"<|26.10|>": 51670,
|
958 |
+
"<|26.12|>": 51671,
|
959 |
+
"<|26.14|>": 51672,
|
960 |
+
"<|26.16|>": 51673,
|
961 |
+
"<|26.18|>": 51674,
|
962 |
+
"<|26.20|>": 51675,
|
963 |
+
"<|26.22|>": 51676,
|
964 |
+
"<|26.24|>": 51677,
|
965 |
+
"<|26.26|>": 51678,
|
966 |
+
"<|26.28|>": 51679,
|
967 |
+
"<|26.30|>": 51680,
|
968 |
+
"<|26.32|>": 51681,
|
969 |
+
"<|26.34|>": 51682,
|
970 |
+
"<|26.36|>": 51683,
|
971 |
+
"<|26.38|>": 51684,
|
972 |
+
"<|26.40|>": 51685,
|
973 |
+
"<|26.42|>": 51686,
|
974 |
+
"<|26.44|>": 51687,
|
975 |
+
"<|26.46|>": 51688,
|
976 |
+
"<|26.48|>": 51689,
|
977 |
+
"<|26.50|>": 51690,
|
978 |
+
"<|26.52|>": 51691,
|
979 |
+
"<|26.54|>": 51692,
|
980 |
+
"<|26.56|>": 51693,
|
981 |
+
"<|26.58|>": 51694,
|
982 |
+
"<|26.60|>": 51695,
|
983 |
+
"<|26.62|>": 51696,
|
984 |
+
"<|26.64|>": 51697,
|
985 |
+
"<|26.66|>": 51698,
|
986 |
+
"<|26.68|>": 51699,
|
987 |
+
"<|26.70|>": 51700,
|
988 |
+
"<|26.72|>": 51701,
|
989 |
+
"<|26.74|>": 51702,
|
990 |
+
"<|26.76|>": 51703,
|
991 |
+
"<|26.78|>": 51704,
|
992 |
+
"<|26.80|>": 51705,
|
993 |
+
"<|26.82|>": 51706,
|
994 |
+
"<|26.84|>": 51707,
|
995 |
+
"<|26.86|>": 51708,
|
996 |
+
"<|26.88|>": 51709,
|
997 |
+
"<|26.90|>": 51710,
|
998 |
+
"<|26.92|>": 51711,
|
999 |
+
"<|26.94|>": 51712,
|
1000 |
+
"<|26.96|>": 51713,
|
1001 |
+
"<|26.98|>": 51714,
|
1002 |
+
"<|27.00|>": 51715,
|
1003 |
+
"<|27.02|>": 51716,
|
1004 |
+
"<|27.04|>": 51717,
|
1005 |
+
"<|27.06|>": 51718,
|
1006 |
+
"<|27.08|>": 51719,
|
1007 |
+
"<|27.10|>": 51720,
|
1008 |
+
"<|27.12|>": 51721,
|
1009 |
+
"<|27.14|>": 51722,
|
1010 |
+
"<|27.16|>": 51723,
|
1011 |
+
"<|27.18|>": 51724,
|
1012 |
+
"<|27.20|>": 51725,
|
1013 |
+
"<|27.22|>": 51726,
|
1014 |
+
"<|27.24|>": 51727,
|
1015 |
+
"<|27.26|>": 51728,
|
1016 |
+
"<|27.28|>": 51729,
|
1017 |
+
"<|27.30|>": 51730,
|
1018 |
+
"<|27.32|>": 51731,
|
1019 |
+
"<|27.34|>": 51732,
|
1020 |
+
"<|27.36|>": 51733,
|
1021 |
+
"<|27.38|>": 51734,
|
1022 |
+
"<|27.40|>": 51735,
|
1023 |
+
"<|27.42|>": 51736,
|
1024 |
+
"<|27.44|>": 51737,
|
1025 |
+
"<|27.46|>": 51738,
|
1026 |
+
"<|27.48|>": 51739,
|
1027 |
+
"<|27.50|>": 51740,
|
1028 |
+
"<|27.52|>": 51741,
|
1029 |
+
"<|27.54|>": 51742,
|
1030 |
+
"<|27.56|>": 51743,
|
1031 |
+
"<|27.58|>": 51744,
|
1032 |
+
"<|27.60|>": 51745,
|
1033 |
+
"<|27.62|>": 51746,
|
1034 |
+
"<|27.64|>": 51747,
|
1035 |
+
"<|27.66|>": 51748,
|
1036 |
+
"<|27.68|>": 51749,
|
1037 |
+
"<|27.70|>": 51750,
|
1038 |
+
"<|27.72|>": 51751,
|
1039 |
+
"<|27.74|>": 51752,
|
1040 |
+
"<|27.76|>": 51753,
|
1041 |
+
"<|27.78|>": 51754,
|
1042 |
+
"<|27.80|>": 51755,
|
1043 |
+
"<|27.82|>": 51756,
|
1044 |
+
"<|27.84|>": 51757,
|
1045 |
+
"<|27.86|>": 51758,
|
1046 |
+
"<|27.88|>": 51759,
|
1047 |
+
"<|27.90|>": 51760,
|
1048 |
+
"<|27.92|>": 51761,
|
1049 |
+
"<|27.94|>": 51762,
|
1050 |
+
"<|27.96|>": 51763,
|
1051 |
+
"<|27.98|>": 51764,
|
1052 |
+
"<|28.00|>": 51765,
|
1053 |
+
"<|28.02|>": 51766,
|
1054 |
+
"<|28.04|>": 51767,
|
1055 |
+
"<|28.06|>": 51768,
|
1056 |
+
"<|28.08|>": 51769,
|
1057 |
+
"<|28.10|>": 51770,
|
1058 |
+
"<|28.12|>": 51771,
|
1059 |
+
"<|28.14|>": 51772,
|
1060 |
+
"<|28.16|>": 51773,
|
1061 |
+
"<|28.18|>": 51774,
|
1062 |
+
"<|28.20|>": 51775,
|
1063 |
+
"<|28.22|>": 51776,
|
1064 |
+
"<|28.24|>": 51777,
|
1065 |
+
"<|28.26|>": 51778,
|
1066 |
+
"<|28.28|>": 51779,
|
1067 |
+
"<|28.30|>": 51780,
|
1068 |
+
"<|28.32|>": 51781,
|
1069 |
+
"<|28.34|>": 51782,
|
1070 |
+
"<|28.36|>": 51783,
|
1071 |
+
"<|28.38|>": 51784,
|
1072 |
+
"<|28.40|>": 51785,
|
1073 |
+
"<|28.42|>": 51786,
|
1074 |
+
"<|28.44|>": 51787,
|
1075 |
+
"<|28.46|>": 51788,
|
1076 |
+
"<|28.48|>": 51789,
|
1077 |
+
"<|28.50|>": 51790,
|
1078 |
+
"<|28.52|>": 51791,
|
1079 |
+
"<|28.54|>": 51792,
|
1080 |
+
"<|28.56|>": 51793,
|
1081 |
+
"<|28.58|>": 51794,
|
1082 |
+
"<|28.60|>": 51795,
|
1083 |
+
"<|28.62|>": 51796,
|
1084 |
+
"<|28.64|>": 51797,
|
1085 |
+
"<|28.66|>": 51798,
|
1086 |
+
"<|28.68|>": 51799,
|
1087 |
+
"<|28.70|>": 51800,
|
1088 |
+
"<|28.72|>": 51801,
|
1089 |
+
"<|28.74|>": 51802,
|
1090 |
+
"<|28.76|>": 51803,
|
1091 |
+
"<|28.78|>": 51804,
|
1092 |
+
"<|28.80|>": 51805,
|
1093 |
+
"<|28.82|>": 51806,
|
1094 |
+
"<|28.84|>": 51807,
|
1095 |
+
"<|28.86|>": 51808,
|
1096 |
+
"<|28.88|>": 51809,
|
1097 |
+
"<|28.90|>": 51810,
|
1098 |
+
"<|28.92|>": 51811,
|
1099 |
+
"<|28.94|>": 51812,
|
1100 |
+
"<|28.96|>": 51813,
|
1101 |
+
"<|28.98|>": 51814,
|
1102 |
+
"<|29.00|>": 51815,
|
1103 |
+
"<|29.02|>": 51816,
|
1104 |
+
"<|29.04|>": 51817,
|
1105 |
+
"<|29.06|>": 51818,
|
1106 |
+
"<|29.08|>": 51819,
|
1107 |
+
"<|29.10|>": 51820,
|
1108 |
+
"<|29.12|>": 51821,
|
1109 |
+
"<|29.14|>": 51822,
|
1110 |
+
"<|29.16|>": 51823,
|
1111 |
+
"<|29.18|>": 51824,
|
1112 |
+
"<|29.20|>": 51825,
|
1113 |
+
"<|29.22|>": 51826,
|
1114 |
+
"<|29.24|>": 51827,
|
1115 |
+
"<|29.26|>": 51828,
|
1116 |
+
"<|29.28|>": 51829,
|
1117 |
+
"<|29.30|>": 51830,
|
1118 |
+
"<|29.32|>": 51831,
|
1119 |
+
"<|29.34|>": 51832,
|
1120 |
+
"<|29.36|>": 51833,
|
1121 |
+
"<|29.38|>": 51834,
|
1122 |
+
"<|29.40|>": 51835,
|
1123 |
+
"<|29.42|>": 51836,
|
1124 |
+
"<|29.44|>": 51837,
|
1125 |
+
"<|29.46|>": 51838,
|
1126 |
+
"<|29.48|>": 51839,
|
1127 |
+
"<|29.50|>": 51840,
|
1128 |
+
"<|29.52|>": 51841,
|
1129 |
+
"<|29.54|>": 51842,
|
1130 |
+
"<|29.56|>": 51843,
|
1131 |
+
"<|29.58|>": 51844,
|
1132 |
+
"<|29.60|>": 51845,
|
1133 |
+
"<|29.62|>": 51846,
|
1134 |
+
"<|29.64|>": 51847,
|
1135 |
+
"<|29.66|>": 51848,
|
1136 |
+
"<|29.68|>": 51849,
|
1137 |
+
"<|29.70|>": 51850,
|
1138 |
+
"<|29.72|>": 51851,
|
1139 |
+
"<|29.74|>": 51852,
|
1140 |
+
"<|29.76|>": 51853,
|
1141 |
+
"<|29.78|>": 51854,
|
1142 |
+
"<|29.80|>": 51855,
|
1143 |
+
"<|29.82|>": 51856,
|
1144 |
+
"<|29.84|>": 51857,
|
1145 |
+
"<|29.86|>": 51858,
|
1146 |
+
"<|29.88|>": 51859,
|
1147 |
+
"<|29.90|>": 51860,
|
1148 |
+
"<|29.92|>": 51861,
|
1149 |
+
"<|29.94|>": 51862,
|
1150 |
+
"<|29.96|>": 51863,
|
1151 |
+
"<|29.98|>": 51864,
|
1152 |
+
"<|3.00|>": 50515,
|
1153 |
+
"<|3.02|>": 50516,
|
1154 |
+
"<|3.04|>": 50517,
|
1155 |
+
"<|3.06|>": 50518,
|
1156 |
+
"<|3.08|>": 50519,
|
1157 |
+
"<|3.10|>": 50520,
|
1158 |
+
"<|3.12|>": 50521,
|
1159 |
+
"<|3.14|>": 50522,
|
1160 |
+
"<|3.16|>": 50523,
|
1161 |
+
"<|3.18|>": 50524,
|
1162 |
+
"<|3.20|>": 50525,
|
1163 |
+
"<|3.22|>": 50526,
|
1164 |
+
"<|3.24|>": 50527,
|
1165 |
+
"<|3.26|>": 50528,
|
1166 |
+
"<|3.28|>": 50529,
|
1167 |
+
"<|3.30|>": 50530,
|
1168 |
+
"<|3.32|>": 50531,
|
1169 |
+
"<|3.34|>": 50532,
|
1170 |
+
"<|3.36|>": 50533,
|
1171 |
+
"<|3.38|>": 50534,
|
1172 |
+
"<|3.40|>": 50535,
|
1173 |
+
"<|3.42|>": 50536,
|
1174 |
+
"<|3.44|>": 50537,
|
1175 |
+
"<|3.46|>": 50538,
|
1176 |
+
"<|3.48|>": 50539,
|
1177 |
+
"<|3.50|>": 50540,
|
1178 |
+
"<|3.52|>": 50541,
|
1179 |
+
"<|3.54|>": 50542,
|
1180 |
+
"<|3.56|>": 50543,
|
1181 |
+
"<|3.58|>": 50544,
|
1182 |
+
"<|3.60|>": 50545,
|
1183 |
+
"<|3.62|>": 50546,
|
1184 |
+
"<|3.64|>": 50547,
|
1185 |
+
"<|3.66|>": 50548,
|
1186 |
+
"<|3.68|>": 50549,
|
1187 |
+
"<|3.70|>": 50550,
|
1188 |
+
"<|3.72|>": 50551,
|
1189 |
+
"<|3.74|>": 50552,
|
1190 |
+
"<|3.76|>": 50553,
|
1191 |
+
"<|3.78|>": 50554,
|
1192 |
+
"<|3.80|>": 50555,
|
1193 |
+
"<|3.82|>": 50556,
|
1194 |
+
"<|3.84|>": 50557,
|
1195 |
+
"<|3.86|>": 50558,
|
1196 |
+
"<|3.88|>": 50559,
|
1197 |
+
"<|3.90|>": 50560,
|
1198 |
+
"<|3.92|>": 50561,
|
1199 |
+
"<|3.94|>": 50562,
|
1200 |
+
"<|3.96|>": 50563,
|
1201 |
+
"<|3.98|>": 50564,
|
1202 |
+
"<|30.00|>": 51865,
|
1203 |
+
"<|4.00|>": 50565,
|
1204 |
+
"<|4.02|>": 50566,
|
1205 |
+
"<|4.04|>": 50567,
|
1206 |
+
"<|4.06|>": 50568,
|
1207 |
+
"<|4.08|>": 50569,
|
1208 |
+
"<|4.10|>": 50570,
|
1209 |
+
"<|4.12|>": 50571,
|
1210 |
+
"<|4.14|>": 50572,
|
1211 |
+
"<|4.16|>": 50573,
|
1212 |
+
"<|4.18|>": 50574,
|
1213 |
+
"<|4.20|>": 50575,
|
1214 |
+
"<|4.22|>": 50576,
|
1215 |
+
"<|4.24|>": 50577,
|
1216 |
+
"<|4.26|>": 50578,
|
1217 |
+
"<|4.28|>": 50579,
|
1218 |
+
"<|4.30|>": 50580,
|
1219 |
+
"<|4.32|>": 50581,
|
1220 |
+
"<|4.34|>": 50582,
|
1221 |
+
"<|4.36|>": 50583,
|
1222 |
+
"<|4.38|>": 50584,
|
1223 |
+
"<|4.40|>": 50585,
|
1224 |
+
"<|4.42|>": 50586,
|
1225 |
+
"<|4.44|>": 50587,
|
1226 |
+
"<|4.46|>": 50588,
|
1227 |
+
"<|4.48|>": 50589,
|
1228 |
+
"<|4.50|>": 50590,
|
1229 |
+
"<|4.52|>": 50591,
|
1230 |
+
"<|4.54|>": 50592,
|
1231 |
+
"<|4.56|>": 50593,
|
1232 |
+
"<|4.58|>": 50594,
|
1233 |
+
"<|4.60|>": 50595,
|
1234 |
+
"<|4.62|>": 50596,
|
1235 |
+
"<|4.64|>": 50597,
|
1236 |
+
"<|4.66|>": 50598,
|
1237 |
+
"<|4.68|>": 50599,
|
1238 |
+
"<|4.70|>": 50600,
|
1239 |
+
"<|4.72|>": 50601,
|
1240 |
+
"<|4.74|>": 50602,
|
1241 |
+
"<|4.76|>": 50603,
|
1242 |
+
"<|4.78|>": 50604,
|
1243 |
+
"<|4.80|>": 50605,
|
1244 |
+
"<|4.82|>": 50606,
|
1245 |
+
"<|4.84|>": 50607,
|
1246 |
+
"<|4.86|>": 50608,
|
1247 |
+
"<|4.88|>": 50609,
|
1248 |
+
"<|4.90|>": 50610,
|
1249 |
+
"<|4.92|>": 50611,
|
1250 |
+
"<|4.94|>": 50612,
|
1251 |
+
"<|4.96|>": 50613,
|
1252 |
+
"<|4.98|>": 50614,
|
1253 |
+
"<|5.00|>": 50615,
|
1254 |
+
"<|5.02|>": 50616,
|
1255 |
+
"<|5.04|>": 50617,
|
1256 |
+
"<|5.06|>": 50618,
|
1257 |
+
"<|5.08|>": 50619,
|
1258 |
+
"<|5.10|>": 50620,
|
1259 |
+
"<|5.12|>": 50621,
|
1260 |
+
"<|5.14|>": 50622,
|
1261 |
+
"<|5.16|>": 50623,
|
1262 |
+
"<|5.18|>": 50624,
|
1263 |
+
"<|5.20|>": 50625,
|
1264 |
+
"<|5.22|>": 50626,
|
1265 |
+
"<|5.24|>": 50627,
|
1266 |
+
"<|5.26|>": 50628,
|
1267 |
+
"<|5.28|>": 50629,
|
1268 |
+
"<|5.30|>": 50630,
|
1269 |
+
"<|5.32|>": 50631,
|
1270 |
+
"<|5.34|>": 50632,
|
1271 |
+
"<|5.36|>": 50633,
|
1272 |
+
"<|5.38|>": 50634,
|
1273 |
+
"<|5.40|>": 50635,
|
1274 |
+
"<|5.42|>": 50636,
|
1275 |
+
"<|5.44|>": 50637,
|
1276 |
+
"<|5.46|>": 50638,
|
1277 |
+
"<|5.48|>": 50639,
|
1278 |
+
"<|5.50|>": 50640,
|
1279 |
+
"<|5.52|>": 50641,
|
1280 |
+
"<|5.54|>": 50642,
|
1281 |
+
"<|5.56|>": 50643,
|
1282 |
+
"<|5.58|>": 50644,
|
1283 |
+
"<|5.60|>": 50645,
|
1284 |
+
"<|5.62|>": 50646,
|
1285 |
+
"<|5.64|>": 50647,
|
1286 |
+
"<|5.66|>": 50648,
|
1287 |
+
"<|5.68|>": 50649,
|
1288 |
+
"<|5.70|>": 50650,
|
1289 |
+
"<|5.72|>": 50651,
|
1290 |
+
"<|5.74|>": 50652,
|
1291 |
+
"<|5.76|>": 50653,
|
1292 |
+
"<|5.78|>": 50654,
|
1293 |
+
"<|5.80|>": 50655,
|
1294 |
+
"<|5.82|>": 50656,
|
1295 |
+
"<|5.84|>": 50657,
|
1296 |
+
"<|5.86|>": 50658,
|
1297 |
+
"<|5.88|>": 50659,
|
1298 |
+
"<|5.90|>": 50660,
|
1299 |
+
"<|5.92|>": 50661,
|
1300 |
+
"<|5.94|>": 50662,
|
1301 |
+
"<|5.96|>": 50663,
|
1302 |
+
"<|5.98|>": 50664,
|
1303 |
+
"<|6.00|>": 50665,
|
1304 |
+
"<|6.02|>": 50666,
|
1305 |
+
"<|6.04|>": 50667,
|
1306 |
+
"<|6.06|>": 50668,
|
1307 |
+
"<|6.08|>": 50669,
|
1308 |
+
"<|6.10|>": 50670,
|
1309 |
+
"<|6.12|>": 50671,
|
1310 |
+
"<|6.14|>": 50672,
|
1311 |
+
"<|6.16|>": 50673,
|
1312 |
+
"<|6.18|>": 50674,
|
1313 |
+
"<|6.20|>": 50675,
|
1314 |
+
"<|6.22|>": 50676,
|
1315 |
+
"<|6.24|>": 50677,
|
1316 |
+
"<|6.26|>": 50678,
|
1317 |
+
"<|6.28|>": 50679,
|
1318 |
+
"<|6.30|>": 50680,
|
1319 |
+
"<|6.32|>": 50681,
|
1320 |
+
"<|6.34|>": 50682,
|
1321 |
+
"<|6.36|>": 50683,
|
1322 |
+
"<|6.38|>": 50684,
|
1323 |
+
"<|6.40|>": 50685,
|
1324 |
+
"<|6.42|>": 50686,
|
1325 |
+
"<|6.44|>": 50687,
|
1326 |
+
"<|6.46|>": 50688,
|
1327 |
+
"<|6.48|>": 50689,
|
1328 |
+
"<|6.50|>": 50690,
|
1329 |
+
"<|6.52|>": 50691,
|
1330 |
+
"<|6.54|>": 50692,
|
1331 |
+
"<|6.56|>": 50693,
|
1332 |
+
"<|6.58|>": 50694,
|
1333 |
+
"<|6.60|>": 50695,
|
1334 |
+
"<|6.62|>": 50696,
|
1335 |
+
"<|6.64|>": 50697,
|
1336 |
+
"<|6.66|>": 50698,
|
1337 |
+
"<|6.68|>": 50699,
|
1338 |
+
"<|6.70|>": 50700,
|
1339 |
+
"<|6.72|>": 50701,
|
1340 |
+
"<|6.74|>": 50702,
|
1341 |
+
"<|6.76|>": 50703,
|
1342 |
+
"<|6.78|>": 50704,
|
1343 |
+
"<|6.80|>": 50705,
|
1344 |
+
"<|6.82|>": 50706,
|
1345 |
+
"<|6.84|>": 50707,
|
1346 |
+
"<|6.86|>": 50708,
|
1347 |
+
"<|6.88|>": 50709,
|
1348 |
+
"<|6.90|>": 50710,
|
1349 |
+
"<|6.92|>": 50711,
|
1350 |
+
"<|6.94|>": 50712,
|
1351 |
+
"<|6.96|>": 50713,
|
1352 |
+
"<|6.98|>": 50714,
|
1353 |
+
"<|7.00|>": 50715,
|
1354 |
+
"<|7.02|>": 50716,
|
1355 |
+
"<|7.04|>": 50717,
|
1356 |
+
"<|7.06|>": 50718,
|
1357 |
+
"<|7.08|>": 50719,
|
1358 |
+
"<|7.10|>": 50720,
|
1359 |
+
"<|7.12|>": 50721,
|
1360 |
+
"<|7.14|>": 50722,
|
1361 |
+
"<|7.16|>": 50723,
|
1362 |
+
"<|7.18|>": 50724,
|
1363 |
+
"<|7.20|>": 50725,
|
1364 |
+
"<|7.22|>": 50726,
|
1365 |
+
"<|7.24|>": 50727,
|
1366 |
+
"<|7.26|>": 50728,
|
1367 |
+
"<|7.28|>": 50729,
|
1368 |
+
"<|7.30|>": 50730,
|
1369 |
+
"<|7.32|>": 50731,
|
1370 |
+
"<|7.34|>": 50732,
|
1371 |
+
"<|7.36|>": 50733,
|
1372 |
+
"<|7.38|>": 50734,
|
1373 |
+
"<|7.40|>": 50735,
|
1374 |
+
"<|7.42|>": 50736,
|
1375 |
+
"<|7.44|>": 50737,
|
1376 |
+
"<|7.46|>": 50738,
|
1377 |
+
"<|7.48|>": 50739,
|
1378 |
+
"<|7.50|>": 50740,
|
1379 |
+
"<|7.52|>": 50741,
|
1380 |
+
"<|7.54|>": 50742,
|
1381 |
+
"<|7.56|>": 50743,
|
1382 |
+
"<|7.58|>": 50744,
|
1383 |
+
"<|7.60|>": 50745,
|
1384 |
+
"<|7.62|>": 50746,
|
1385 |
+
"<|7.64|>": 50747,
|
1386 |
+
"<|7.66|>": 50748,
|
1387 |
+
"<|7.68|>": 50749,
|
1388 |
+
"<|7.70|>": 50750,
|
1389 |
+
"<|7.72|>": 50751,
|
1390 |
+
"<|7.74|>": 50752,
|
1391 |
+
"<|7.76|>": 50753,
|
1392 |
+
"<|7.78|>": 50754,
|
1393 |
+
"<|7.80|>": 50755,
|
1394 |
+
"<|7.82|>": 50756,
|
1395 |
+
"<|7.84|>": 50757,
|
1396 |
+
"<|7.86|>": 50758,
|
1397 |
+
"<|7.88|>": 50759,
|
1398 |
+
"<|7.90|>": 50760,
|
1399 |
+
"<|7.92|>": 50761,
|
1400 |
+
"<|7.94|>": 50762,
|
1401 |
+
"<|7.96|>": 50763,
|
1402 |
+
"<|7.98|>": 50764,
|
1403 |
+
"<|8.00|>": 50765,
|
1404 |
+
"<|8.02|>": 50766,
|
1405 |
+
"<|8.04|>": 50767,
|
1406 |
+
"<|8.06|>": 50768,
|
1407 |
+
"<|8.08|>": 50769,
|
1408 |
+
"<|8.10|>": 50770,
|
1409 |
+
"<|8.12|>": 50771,
|
1410 |
+
"<|8.14|>": 50772,
|
1411 |
+
"<|8.16|>": 50773,
|
1412 |
+
"<|8.18|>": 50774,
|
1413 |
+
"<|8.20|>": 50775,
|
1414 |
+
"<|8.22|>": 50776,
|
1415 |
+
"<|8.24|>": 50777,
|
1416 |
+
"<|8.26|>": 50778,
|
1417 |
+
"<|8.28|>": 50779,
|
1418 |
+
"<|8.30|>": 50780,
|
1419 |
+
"<|8.32|>": 50781,
|
1420 |
+
"<|8.34|>": 50782,
|
1421 |
+
"<|8.36|>": 50783,
|
1422 |
+
"<|8.38|>": 50784,
|
1423 |
+
"<|8.40|>": 50785,
|
1424 |
+
"<|8.42|>": 50786,
|
1425 |
+
"<|8.44|>": 50787,
|
1426 |
+
"<|8.46|>": 50788,
|
1427 |
+
"<|8.48|>": 50789,
|
1428 |
+
"<|8.50|>": 50790,
|
1429 |
+
"<|8.52|>": 50791,
|
1430 |
+
"<|8.54|>": 50792,
|
1431 |
+
"<|8.56|>": 50793,
|
1432 |
+
"<|8.58|>": 50794,
|
1433 |
+
"<|8.60|>": 50795,
|
1434 |
+
"<|8.62|>": 50796,
|
1435 |
+
"<|8.64|>": 50797,
|
1436 |
+
"<|8.66|>": 50798,
|
1437 |
+
"<|8.68|>": 50799,
|
1438 |
+
"<|8.70|>": 50800,
|
1439 |
+
"<|8.72|>": 50801,
|
1440 |
+
"<|8.74|>": 50802,
|
1441 |
+
"<|8.76|>": 50803,
|
1442 |
+
"<|8.78|>": 50804,
|
1443 |
+
"<|8.80|>": 50805,
|
1444 |
+
"<|8.82|>": 50806,
|
1445 |
+
"<|8.84|>": 50807,
|
1446 |
+
"<|8.86|>": 50808,
|
1447 |
+
"<|8.88|>": 50809,
|
1448 |
+
"<|8.90|>": 50810,
|
1449 |
+
"<|8.92|>": 50811,
|
1450 |
+
"<|8.94|>": 50812,
|
1451 |
+
"<|8.96|>": 50813,
|
1452 |
+
"<|8.98|>": 50814,
|
1453 |
+
"<|9.00|>": 50815,
|
1454 |
+
"<|9.02|>": 50816,
|
1455 |
+
"<|9.04|>": 50817,
|
1456 |
+
"<|9.06|>": 50818,
|
1457 |
+
"<|9.08|>": 50819,
|
1458 |
+
"<|9.10|>": 50820,
|
1459 |
+
"<|9.12|>": 50821,
|
1460 |
+
"<|9.14|>": 50822,
|
1461 |
+
"<|9.16|>": 50823,
|
1462 |
+
"<|9.18|>": 50824,
|
1463 |
+
"<|9.20|>": 50825,
|
1464 |
+
"<|9.22|>": 50826,
|
1465 |
+
"<|9.24|>": 50827,
|
1466 |
+
"<|9.26|>": 50828,
|
1467 |
+
"<|9.28|>": 50829,
|
1468 |
+
"<|9.30|>": 50830,
|
1469 |
+
"<|9.32|>": 50831,
|
1470 |
+
"<|9.34|>": 50832,
|
1471 |
+
"<|9.36|>": 50833,
|
1472 |
+
"<|9.38|>": 50834,
|
1473 |
+
"<|9.40|>": 50835,
|
1474 |
+
"<|9.42|>": 50836,
|
1475 |
+
"<|9.44|>": 50837,
|
1476 |
+
"<|9.46|>": 50838,
|
1477 |
+
"<|9.48|>": 50839,
|
1478 |
+
"<|9.50|>": 50840,
|
1479 |
+
"<|9.52|>": 50841,
|
1480 |
+
"<|9.54|>": 50842,
|
1481 |
+
"<|9.56|>": 50843,
|
1482 |
+
"<|9.58|>": 50844,
|
1483 |
+
"<|9.60|>": 50845,
|
1484 |
+
"<|9.62|>": 50846,
|
1485 |
+
"<|9.64|>": 50847,
|
1486 |
+
"<|9.66|>": 50848,
|
1487 |
+
"<|9.68|>": 50849,
|
1488 |
+
"<|9.70|>": 50850,
|
1489 |
+
"<|9.72|>": 50851,
|
1490 |
+
"<|9.74|>": 50852,
|
1491 |
+
"<|9.76|>": 50853,
|
1492 |
+
"<|9.78|>": 50854,
|
1493 |
+
"<|9.80|>": 50855,
|
1494 |
+
"<|9.82|>": 50856,
|
1495 |
+
"<|9.84|>": 50857,
|
1496 |
+
"<|9.86|>": 50858,
|
1497 |
+
"<|9.88|>": 50859,
|
1498 |
+
"<|9.90|>": 50860,
|
1499 |
+
"<|9.92|>": 50861,
|
1500 |
+
"<|9.94|>": 50862,
|
1501 |
+
"<|9.96|>": 50863,
|
1502 |
+
"<|9.98|>": 50864,
|
1503 |
+
"<|af|>": 50327,
|
1504 |
+
"<|am|>": 50334,
|
1505 |
+
"<|ar|>": 50272,
|
1506 |
+
"<|as|>": 50350,
|
1507 |
+
"<|az|>": 50304,
|
1508 |
+
"<|ba|>": 50355,
|
1509 |
+
"<|be|>": 50330,
|
1510 |
+
"<|bg|>": 50292,
|
1511 |
+
"<|bn|>": 50302,
|
1512 |
+
"<|bo|>": 50347,
|
1513 |
+
"<|br|>": 50309,
|
1514 |
+
"<|bs|>": 50315,
|
1515 |
+
"<|ca|>": 50270,
|
1516 |
+
"<|cs|>": 50283,
|
1517 |
+
"<|cy|>": 50297,
|
1518 |
+
"<|da|>": 50285,
|
1519 |
+
"<|de|>": 50261,
|
1520 |
+
"<|el|>": 50281,
|
1521 |
+
"<|endoftext|>": 50257,
|
1522 |
+
"<|en|>": 50259,
|
1523 |
+
"<|es|>": 50262,
|
1524 |
+
"<|et|>": 50307,
|
1525 |
+
"<|eu|>": 50310,
|
1526 |
+
"<|fa|>": 50300,
|
1527 |
+
"<|fi|>": 50277,
|
1528 |
+
"<|fo|>": 50338,
|
1529 |
+
"<|fr|>": 50265,
|
1530 |
+
"<|gl|>": 50319,
|
1531 |
+
"<|gu|>": 50333,
|
1532 |
+
"<|haw|>": 50352,
|
1533 |
+
"<|ha|>": 50354,
|
1534 |
+
"<|he|>": 50279,
|
1535 |
+
"<|hi|>": 50276,
|
1536 |
+
"<|hr|>": 50291,
|
1537 |
+
"<|ht|>": 50339,
|
1538 |
+
"<|hu|>": 50286,
|
1539 |
+
"<|hy|>": 50312,
|
1540 |
+
"<|id|>": 50275,
|
1541 |
+
"<|is|>": 50311,
|
1542 |
+
"<|it|>": 50274,
|
1543 |
+
"<|ja|>": 50266,
|
1544 |
+
"<|jw|>": 50356,
|
1545 |
+
"<|ka|>": 50329,
|
1546 |
+
"<|kk|>": 50316,
|
1547 |
+
"<|km|>": 50323,
|
1548 |
+
"<|kn|>": 50306,
|
1549 |
+
"<|ko|>": 50264,
|
1550 |
+
"<|la|>": 50294,
|
1551 |
+
"<|lb|>": 50345,
|
1552 |
+
"<|ln|>": 50353,
|
1553 |
+
"<|lo|>": 50336,
|
1554 |
+
"<|lt|>": 50293,
|
1555 |
+
"<|lv|>": 50301,
|
1556 |
+
"<|mg|>": 50349,
|
1557 |
+
"<|mi|>": 50295,
|
1558 |
+
"<|mk|>": 50308,
|
1559 |
+
"<|ml|>": 50296,
|
1560 |
+
"<|mn|>": 50314,
|
1561 |
+
"<|mr|>": 50320,
|
1562 |
+
"<|ms|>": 50282,
|
1563 |
+
"<|mt|>": 50343,
|
1564 |
+
"<|my|>": 50346,
|
1565 |
+
"<|ne|>": 50313,
|
1566 |
+
"<|nl|>": 50271,
|
1567 |
+
"<|nn|>": 50342,
|
1568 |
+
"<|nospeech|>": 50363,
|
1569 |
+
"<|notimestamps|>": 50364,
|
1570 |
+
"<|no|>": 50288,
|
1571 |
+
"<|oc|>": 50328,
|
1572 |
+
"<|pa|>": 50321,
|
1573 |
+
"<|pl|>": 50269,
|
1574 |
+
"<|ps|>": 50340,
|
1575 |
+
"<|pt|>": 50267,
|
1576 |
+
"<|ro|>": 50284,
|
1577 |
+
"<|ru|>": 50263,
|
1578 |
+
"<|sa|>": 50344,
|
1579 |
+
"<|sd|>": 50332,
|
1580 |
+
"<|si|>": 50322,
|
1581 |
+
"<|sk|>": 50298,
|
1582 |
+
"<|sl|>": 50305,
|
1583 |
+
"<|sn|>": 50324,
|
1584 |
+
"<|so|>": 50326,
|
1585 |
+
"<|sq|>": 50317,
|
1586 |
+
"<|sr|>": 50303,
|
1587 |
+
"<|startoflm|>": 50361,
|
1588 |
+
"<|startofprev|>": 50362,
|
1589 |
+
"<|startoftranscript|>": 50258,
|
1590 |
+
"<|su|>": 50357,
|
1591 |
+
"<|sv|>": 50273,
|
1592 |
+
"<|sw|>": 50318,
|
1593 |
+
"<|ta|>": 50287,
|
1594 |
+
"<|te|>": 50299,
|
1595 |
+
"<|tg|>": 50331,
|
1596 |
+
"<|th|>": 50289,
|
1597 |
+
"<|tk|>": 50341,
|
1598 |
+
"<|tl|>": 50348,
|
1599 |
+
"<|transcribe|>": 50360,
|
1600 |
+
"<|translate|>": 50359,
|
1601 |
+
"<|tr|>": 50268,
|
1602 |
+
"<|tt|>": 50351,
|
1603 |
+
"<|uk|>": 50280,
|
1604 |
+
"<|ur|>": 50290,
|
1605 |
+
"<|uz|>": 50337,
|
1606 |
+
"<|vi|>": 50278,
|
1607 |
+
"<|yi|>": 50335,
|
1608 |
+
"<|yo|>": 50325,
|
1609 |
+
"<|yue|>": 50358,
|
1610 |
+
"<|zh|>": 50260
|
1611 |
+
}
|
checkpoint-5000-epoch-0/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df305b3f7009eab623e7fecd9ce8e65b95fc4a2e34b3175d7ae6b0441fc8ea0f
|
3 |
+
size 3025686376
|
checkpoint-5000-epoch-0/model_1.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e39aef3bbaf23383a1b2b68ca0e62361f1499dcfae56a696dbe942557f54e9e1
|
3 |
+
size 4361070048
|
checkpoint-5000-epoch-0/optimizer.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:beba5fa101effd6ef750b055ff05941f381beb7237601eef8ed9ce5d38068f57
|
3 |
+
size 955539578
|
checkpoint-5000-epoch-0/random_states_0.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63ea449c3ffff0f8fd79b27243bc394123729a01294a9f217d4048bd8a330d73
|
3 |
+
size 14604
|
checkpoint-5000-epoch-0/random_states_1.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:761015fa832e44df5631ee8ab425a10288697c42533512b467f11930c6b5effe
|
3 |
+
size 14604
|
checkpoint-5000-epoch-0/scheduler.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:155f9eb76013e7e018690f6bb69927b931a2be52f0f4eedf79eb9dfeecc9f35d
|
3 |
+
size 1064
|
config.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "GalaktischeGurke/swhisper_large_8552",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"apply_spec_augment": false,
|
6 |
+
"architectures": [
|
7 |
+
"WhisperForConditionalGeneration"
|
8 |
+
],
|
9 |
+
"attention_dropout": 0.0,
|
10 |
+
"begin_suppress_tokens": [
|
11 |
+
220,
|
12 |
+
50257
|
13 |
+
],
|
14 |
+
"bos_token_id": 50257,
|
15 |
+
"classifier_proj_size": 256,
|
16 |
+
"d_model": 1280,
|
17 |
+
"decoder_attention_heads": 20,
|
18 |
+
"decoder_ffn_dim": 5120,
|
19 |
+
"decoder_layerdrop": 0.0,
|
20 |
+
"decoder_layers": 2,
|
21 |
+
"decoder_start_token_id": 50258,
|
22 |
+
"dropout": 0.0,
|
23 |
+
"encoder_attention_heads": 20,
|
24 |
+
"encoder_ffn_dim": 5120,
|
25 |
+
"encoder_layerdrop": 0.0,
|
26 |
+
"encoder_layers": 32,
|
27 |
+
"eos_token_id": 50257,
|
28 |
+
"init_std": 0.02,
|
29 |
+
"is_encoder_decoder": true,
|
30 |
+
"mask_feature_length": 10,
|
31 |
+
"mask_feature_min_masks": 0,
|
32 |
+
"mask_feature_prob": 0.0,
|
33 |
+
"mask_time_length": 10,
|
34 |
+
"mask_time_min_masks": 2,
|
35 |
+
"mask_time_prob": 0.05,
|
36 |
+
"max_length": 448,
|
37 |
+
"max_source_positions": 1500,
|
38 |
+
"max_target_positions": 448,
|
39 |
+
"median_filter_width": 7,
|
40 |
+
"model_type": "whisper",
|
41 |
+
"num_hidden_layers": 32,
|
42 |
+
"num_mel_bins": 128,
|
43 |
+
"pad_token_id": 50256,
|
44 |
+
"scale_embedding": false,
|
45 |
+
"torch_dtype": "float32",
|
46 |
+
"transformers_version": "4.41.2",
|
47 |
+
"use_cache": true,
|
48 |
+
"use_weighted_layer_sum": false,
|
49 |
+
"vocab_size": 51866
|
50 |
+
}
|
core
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2542a5c60b340914a9a695344dc58cca5a8ed2ca317dcac77bcc49bda0fc93b2
|
3 |
+
size 4742078464
|
create_student_model.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Initialise a student Whisper model from a pre-trained teacher model for
|
18 |
+
teacher-student distillation.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import copy
|
23 |
+
import logging
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
parser = argparse.ArgumentParser(
|
35 |
+
description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--teacher_checkpoint",
|
39 |
+
type=str,
|
40 |
+
required=True,
|
41 |
+
help="The HF Hub ID of the teacher checkpoint.",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--subfolder",
|
45 |
+
type=str,
|
46 |
+
default="",
|
47 |
+
help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
|
48 |
+
"can specify the folder name here.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--encoder_layers",
|
52 |
+
type=int,
|
53 |
+
default=None,
|
54 |
+
help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--decoder_layers",
|
58 |
+
type=int,
|
59 |
+
default=2,
|
60 |
+
help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--decoder_layers_numbers",
|
64 |
+
type=int,
|
65 |
+
nargs="*",
|
66 |
+
help="Layers numbers of the decoder teacher to use in the student model. Defaults to None, equivalent to taking first and last layer (and equivalent to `--decoder_layers_numbers 0 -1`).",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--save_dir",
|
70 |
+
type=str,
|
71 |
+
required=True,
|
72 |
+
help="Where to save the student weights and processor.",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--push_to_hub",
|
76 |
+
type=bool,
|
77 |
+
required=False,
|
78 |
+
default=False,
|
79 |
+
help="Whether to push the student weights and processor to the Hub.",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--cache_dir",
|
83 |
+
type=str,
|
84 |
+
default=None,
|
85 |
+
help="Where to store the pretrained models downloaded from huggingface.co",
|
86 |
+
)
|
87 |
+
|
88 |
+
args = parser.parse_args()
|
89 |
+
return args
|
90 |
+
|
91 |
+
|
92 |
+
def init_student_model_from_teacher(
|
93 |
+
teacher_checkpoint,
|
94 |
+
encoder_layers=None,
|
95 |
+
decoder_layers=2,
|
96 |
+
decoder_layers_numbers=None,
|
97 |
+
save_dir=None,
|
98 |
+
push_to_hub=None,
|
99 |
+
cache_dir=None,
|
100 |
+
subfolder="",
|
101 |
+
):
|
102 |
+
if decoder_layers_numbers is not None and len(decoder_layers_numbers) != decoder_layers:
|
103 |
+
raise ValueError(
|
104 |
+
f"Got {len(decoder_layers_numbers)} layers number for {decoder_layers} decoder layers."
|
105 |
+
)
|
106 |
+
|
107 |
+
teacher_model = WhisperForConditionalGeneration.from_pretrained(
|
108 |
+
teacher_checkpoint,
|
109 |
+
cache_dir=cache_dir,
|
110 |
+
subfolder=subfolder,
|
111 |
+
low_cpu_mem_usage=True,
|
112 |
+
)
|
113 |
+
processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
|
114 |
+
generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
|
115 |
+
generation_config.forced_decoder_ids = None
|
116 |
+
|
117 |
+
teacher_config = teacher_model.config
|
118 |
+
teacher_encoder_layers = teacher_config.encoder_layers
|
119 |
+
teacher_decoder_layers = teacher_config.decoder_layers
|
120 |
+
|
121 |
+
student_config = copy.deepcopy(teacher_config)
|
122 |
+
student_config.update(
|
123 |
+
{
|
124 |
+
"encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
|
125 |
+
"decoder_layers": decoder_layers,
|
126 |
+
}
|
127 |
+
)
|
128 |
+
|
129 |
+
encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
|
130 |
+
encoder_mapping[-1] = teacher_encoder_layers - 1
|
131 |
+
|
132 |
+
encoder_map = {}
|
133 |
+
for student_layer, teacher_layer in enumerate(encoder_mapping):
|
134 |
+
encoder_map[teacher_layer] = student_layer
|
135 |
+
|
136 |
+
if decoder_layers_numbers is None:
|
137 |
+
decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
|
138 |
+
decoder_mapping[-1] = teacher_decoder_layers - 1
|
139 |
+
else:
|
140 |
+
decoder_mapping = decoder_layers_numbers
|
141 |
+
|
142 |
+
decoder_map = {}
|
143 |
+
for student_layer, teacher_layer in enumerate(decoder_mapping):
|
144 |
+
decoder_map[teacher_layer] = student_layer
|
145 |
+
|
146 |
+
# init the student params from the teacher model
|
147 |
+
student_model = WhisperForConditionalGeneration(student_config)
|
148 |
+
missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
|
149 |
+
if len(missing_keys) > 0:
|
150 |
+
raise RuntimeError(
|
151 |
+
"Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
|
152 |
+
f"Missing key(s) in state_dict: {missing_keys}"
|
153 |
+
)
|
154 |
+
if decoder_layers == teacher_decoder_layers:
|
155 |
+
decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
|
156 |
+
if len(decoder_keys) > 0:
|
157 |
+
raise RuntimeError(
|
158 |
+
"Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
|
159 |
+
f"Unexpected key(s) in state_dict: {decoder_keys}"
|
160 |
+
)
|
161 |
+
if encoder_layers == teacher_encoder_layers:
|
162 |
+
encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
|
163 |
+
if len(encoder_keys) > 0:
|
164 |
+
raise RuntimeError(
|
165 |
+
"Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
|
166 |
+
f"Unexpected key(s) in state_dict: {encoder_keys}"
|
167 |
+
)
|
168 |
+
|
169 |
+
for layer in range(teacher_decoder_layers):
|
170 |
+
if layer in decoder_map:
|
171 |
+
# re-introduce pre-defined layers from the teacher
|
172 |
+
student_model.model.decoder.layers[decoder_map[layer]].load_state_dict(
|
173 |
+
teacher_model.model.decoder.layers[layer].state_dict()
|
174 |
+
)
|
175 |
+
|
176 |
+
if encoder_layers is not None:
|
177 |
+
for layer in range(teacher_encoder_layers):
|
178 |
+
if layer in encoder_map:
|
179 |
+
# re-introduce pre-defined layers from the teacher
|
180 |
+
student_model.model.encoder.layers[encoder_map[layer]].load_state_dict(
|
181 |
+
teacher_model.model.encoder.layers[layer].state_dict()
|
182 |
+
)
|
183 |
+
|
184 |
+
# remove the teacher params and model
|
185 |
+
del teacher_model
|
186 |
+
|
187 |
+
# save the converted weights and model
|
188 |
+
if save_dir is not None:
|
189 |
+
student_model.save_pretrained(save_dir)
|
190 |
+
# we also need to correctly save the processor and generation config
|
191 |
+
processor.save_pretrained(save_dir)
|
192 |
+
generation_config.save_pretrained(save_dir)
|
193 |
+
|
194 |
+
# check we can do a forward pass with the saved model - first load the weights and processor
|
195 |
+
logger.info("Checking we can load the saved model...")
|
196 |
+
student_model = WhisperForConditionalGeneration.from_pretrained(
|
197 |
+
save_dir,
|
198 |
+
low_cpu_mem_usage=True,
|
199 |
+
)
|
200 |
+
processor = WhisperProcessor.from_pretrained(save_dir)
|
201 |
+
|
202 |
+
# define some random inputs
|
203 |
+
input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
|
204 |
+
decoder_start_token_id = student_model.config.decoder_start_token_id
|
205 |
+
decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id
|
206 |
+
|
207 |
+
# do a forward pass - outputs will be gibberish for the initialised model so we can't check them
|
208 |
+
# but we make can sure the model runs as expected
|
209 |
+
logger.info("Checking we can run the converted model forward...")
|
210 |
+
_ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
|
211 |
+
logger.info("Conversion successful!")
|
212 |
+
|
213 |
+
if push_to_hub:
|
214 |
+
student_model.push_to_hub(save_dir)
|
215 |
+
processor.push_to_hub(save_dir)
|
216 |
+
generation_config.push_to_hub(save_dir)
|
217 |
+
|
218 |
+
|
219 |
+
if __name__ == "__main__":
|
220 |
+
args = parse_args()
|
221 |
+
|
222 |
+
init_student_model_from_teacher(
|
223 |
+
teacher_checkpoint=args.teacher_checkpoint,
|
224 |
+
encoder_layers=args.encoder_layers,
|
225 |
+
decoder_layers=args.decoder_layers,
|
226 |
+
decoder_layers_numbers=args.decoder_layers_numbers,
|
227 |
+
save_dir=args.save_dir,
|
228 |
+
push_to_hub=args.push_to_hub,
|
229 |
+
cache_dir=args.cache_dir,
|
230 |
+
subfolder=args.subfolder,
|
231 |
+
)
|
distil-large-v3-init/added_tokens.json
ADDED
@@ -0,0 +1,1611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|0.00|>": 50365,
|
3 |
+
"<|0.02|>": 50366,
|
4 |
+
"<|0.04|>": 50367,
|
5 |
+
"<|0.06|>": 50368,
|
6 |
+
"<|0.08|>": 50369,
|
7 |
+
"<|0.10|>": 50370,
|
8 |
+
"<|0.12|>": 50371,
|
9 |
+
"<|0.14|>": 50372,
|
10 |
+
"<|0.16|>": 50373,
|
11 |
+
"<|0.18|>": 50374,
|
12 |
+
"<|0.20|>": 50375,
|
13 |
+
"<|0.22|>": 50376,
|
14 |
+
"<|0.24|>": 50377,
|
15 |
+
"<|0.26|>": 50378,
|
16 |
+
"<|0.28|>": 50379,
|
17 |
+
"<|0.30|>": 50380,
|
18 |
+
"<|0.32|>": 50381,
|
19 |
+
"<|0.34|>": 50382,
|
20 |
+
"<|0.36|>": 50383,
|
21 |
+
"<|0.38|>": 50384,
|
22 |
+
"<|0.40|>": 50385,
|
23 |
+
"<|0.42|>": 50386,
|
24 |
+
"<|0.44|>": 50387,
|
25 |
+
"<|0.46|>": 50388,
|
26 |
+
"<|0.48|>": 50389,
|
27 |
+
"<|0.50|>": 50390,
|
28 |
+
"<|0.52|>": 50391,
|
29 |
+
"<|0.54|>": 50392,
|
30 |
+
"<|0.56|>": 50393,
|
31 |
+
"<|0.58|>": 50394,
|
32 |
+
"<|0.60|>": 50395,
|
33 |
+
"<|0.62|>": 50396,
|
34 |
+
"<|0.64|>": 50397,
|
35 |
+
"<|0.66|>": 50398,
|
36 |
+
"<|0.68|>": 50399,
|
37 |
+
"<|0.70|>": 50400,
|
38 |
+
"<|0.72|>": 50401,
|
39 |
+
"<|0.74|>": 50402,
|
40 |
+
"<|0.76|>": 50403,
|
41 |
+
"<|0.78|>": 50404,
|
42 |
+
"<|0.80|>": 50405,
|
43 |
+
"<|0.82|>": 50406,
|
44 |
+
"<|0.84|>": 50407,
|
45 |
+
"<|0.86|>": 50408,
|
46 |
+
"<|0.88|>": 50409,
|
47 |
+
"<|0.90|>": 50410,
|
48 |
+
"<|0.92|>": 50411,
|
49 |
+
"<|0.94|>": 50412,
|
50 |
+
"<|0.96|>": 50413,
|
51 |
+
"<|0.98|>": 50414,
|
52 |
+
"<|1.00|>": 50415,
|
53 |
+
"<|1.02|>": 50416,
|
54 |
+
"<|1.04|>": 50417,
|
55 |
+
"<|1.06|>": 50418,
|
56 |
+
"<|1.08|>": 50419,
|
57 |
+
"<|1.10|>": 50420,
|
58 |
+
"<|1.12|>": 50421,
|
59 |
+
"<|1.14|>": 50422,
|
60 |
+
"<|1.16|>": 50423,
|
61 |
+
"<|1.18|>": 50424,
|
62 |
+
"<|1.20|>": 50425,
|
63 |
+
"<|1.22|>": 50426,
|
64 |
+
"<|1.24|>": 50427,
|
65 |
+
"<|1.26|>": 50428,
|
66 |
+
"<|1.28|>": 50429,
|
67 |
+
"<|1.30|>": 50430,
|
68 |
+
"<|1.32|>": 50431,
|
69 |
+
"<|1.34|>": 50432,
|
70 |
+
"<|1.36|>": 50433,
|
71 |
+
"<|1.38|>": 50434,
|
72 |
+
"<|1.40|>": 50435,
|
73 |
+
"<|1.42|>": 50436,
|
74 |
+
"<|1.44|>": 50437,
|
75 |
+
"<|1.46|>": 50438,
|
76 |
+
"<|1.48|>": 50439,
|
77 |
+
"<|1.50|>": 50440,
|
78 |
+
"<|1.52|>": 50441,
|
79 |
+
"<|1.54|>": 50442,
|
80 |
+
"<|1.56|>": 50443,
|
81 |
+
"<|1.58|>": 50444,
|
82 |
+
"<|1.60|>": 50445,
|
83 |
+
"<|1.62|>": 50446,
|
84 |
+
"<|1.64|>": 50447,
|
85 |
+
"<|1.66|>": 50448,
|
86 |
+
"<|1.68|>": 50449,
|
87 |
+
"<|1.70|>": 50450,
|
88 |
+
"<|1.72|>": 50451,
|
89 |
+
"<|1.74|>": 50452,
|
90 |
+
"<|1.76|>": 50453,
|
91 |
+
"<|1.78|>": 50454,
|
92 |
+
"<|1.80|>": 50455,
|
93 |
+
"<|1.82|>": 50456,
|
94 |
+
"<|1.84|>": 50457,
|
95 |
+
"<|1.86|>": 50458,
|
96 |
+
"<|1.88|>": 50459,
|
97 |
+
"<|1.90|>": 50460,
|
98 |
+
"<|1.92|>": 50461,
|
99 |
+
"<|1.94|>": 50462,
|
100 |
+
"<|1.96|>": 50463,
|
101 |
+
"<|1.98|>": 50464,
|
102 |
+
"<|10.00|>": 50865,
|
103 |
+
"<|10.02|>": 50866,
|
104 |
+
"<|10.04|>": 50867,
|
105 |
+
"<|10.06|>": 50868,
|
106 |
+
"<|10.08|>": 50869,
|
107 |
+
"<|10.10|>": 50870,
|
108 |
+
"<|10.12|>": 50871,
|
109 |
+
"<|10.14|>": 50872,
|
110 |
+
"<|10.16|>": 50873,
|
111 |
+
"<|10.18|>": 50874,
|
112 |
+
"<|10.20|>": 50875,
|
113 |
+
"<|10.22|>": 50876,
|
114 |
+
"<|10.24|>": 50877,
|
115 |
+
"<|10.26|>": 50878,
|
116 |
+
"<|10.28|>": 50879,
|
117 |
+
"<|10.30|>": 50880,
|
118 |
+
"<|10.32|>": 50881,
|
119 |
+
"<|10.34|>": 50882,
|
120 |
+
"<|10.36|>": 50883,
|
121 |
+
"<|10.38|>": 50884,
|
122 |
+
"<|10.40|>": 50885,
|
123 |
+
"<|10.42|>": 50886,
|
124 |
+
"<|10.44|>": 50887,
|
125 |
+
"<|10.46|>": 50888,
|
126 |
+
"<|10.48|>": 50889,
|
127 |
+
"<|10.50|>": 50890,
|
128 |
+
"<|10.52|>": 50891,
|
129 |
+
"<|10.54|>": 50892,
|
130 |
+
"<|10.56|>": 50893,
|
131 |
+
"<|10.58|>": 50894,
|
132 |
+
"<|10.60|>": 50895,
|
133 |
+
"<|10.62|>": 50896,
|
134 |
+
"<|10.64|>": 50897,
|
135 |
+
"<|10.66|>": 50898,
|
136 |
+
"<|10.68|>": 50899,
|
137 |
+
"<|10.70|>": 50900,
|
138 |
+
"<|10.72|>": 50901,
|
139 |
+
"<|10.74|>": 50902,
|
140 |
+
"<|10.76|>": 50903,
|
141 |
+
"<|10.78|>": 50904,
|
142 |
+
"<|10.80|>": 50905,
|
143 |
+
"<|10.82|>": 50906,
|
144 |
+
"<|10.84|>": 50907,
|
145 |
+
"<|10.86|>": 50908,
|
146 |
+
"<|10.88|>": 50909,
|
147 |
+
"<|10.90|>": 50910,
|
148 |
+
"<|10.92|>": 50911,
|
149 |
+
"<|10.94|>": 50912,
|
150 |
+
"<|10.96|>": 50913,
|
151 |
+
"<|10.98|>": 50914,
|
152 |
+
"<|11.00|>": 50915,
|
153 |
+
"<|11.02|>": 50916,
|
154 |
+
"<|11.04|>": 50917,
|
155 |
+
"<|11.06|>": 50918,
|
156 |
+
"<|11.08|>": 50919,
|
157 |
+
"<|11.10|>": 50920,
|
158 |
+
"<|11.12|>": 50921,
|
159 |
+
"<|11.14|>": 50922,
|
160 |
+
"<|11.16|>": 50923,
|
161 |
+
"<|11.18|>": 50924,
|
162 |
+
"<|11.20|>": 50925,
|
163 |
+
"<|11.22|>": 50926,
|
164 |
+
"<|11.24|>": 50927,
|
165 |
+
"<|11.26|>": 50928,
|
166 |
+
"<|11.28|>": 50929,
|
167 |
+
"<|11.30|>": 50930,
|
168 |
+
"<|11.32|>": 50931,
|
169 |
+
"<|11.34|>": 50932,
|
170 |
+
"<|11.36|>": 50933,
|
171 |
+
"<|11.38|>": 50934,
|
172 |
+
"<|11.40|>": 50935,
|
173 |
+
"<|11.42|>": 50936,
|
174 |
+
"<|11.44|>": 50937,
|
175 |
+
"<|11.46|>": 50938,
|
176 |
+
"<|11.48|>": 50939,
|
177 |
+
"<|11.50|>": 50940,
|
178 |
+
"<|11.52|>": 50941,
|
179 |
+
"<|11.54|>": 50942,
|
180 |
+
"<|11.56|>": 50943,
|
181 |
+
"<|11.58|>": 50944,
|
182 |
+
"<|11.60|>": 50945,
|
183 |
+
"<|11.62|>": 50946,
|
184 |
+
"<|11.64|>": 50947,
|
185 |
+
"<|11.66|>": 50948,
|
186 |
+
"<|11.68|>": 50949,
|
187 |
+
"<|11.70|>": 50950,
|
188 |
+
"<|11.72|>": 50951,
|
189 |
+
"<|11.74|>": 50952,
|
190 |
+
"<|11.76|>": 50953,
|
191 |
+
"<|11.78|>": 50954,
|
192 |
+
"<|11.80|>": 50955,
|
193 |
+
"<|11.82|>": 50956,
|
194 |
+
"<|11.84|>": 50957,
|
195 |
+
"<|11.86|>": 50958,
|
196 |
+
"<|11.88|>": 50959,
|
197 |
+
"<|11.90|>": 50960,
|
198 |
+
"<|11.92|>": 50961,
|
199 |
+
"<|11.94|>": 50962,
|
200 |
+
"<|11.96|>": 50963,
|
201 |
+
"<|11.98|>": 50964,
|
202 |
+
"<|12.00|>": 50965,
|
203 |
+
"<|12.02|>": 50966,
|
204 |
+
"<|12.04|>": 50967,
|
205 |
+
"<|12.06|>": 50968,
|
206 |
+
"<|12.08|>": 50969,
|
207 |
+
"<|12.10|>": 50970,
|
208 |
+
"<|12.12|>": 50971,
|
209 |
+
"<|12.14|>": 50972,
|
210 |
+
"<|12.16|>": 50973,
|
211 |
+
"<|12.18|>": 50974,
|
212 |
+
"<|12.20|>": 50975,
|
213 |
+
"<|12.22|>": 50976,
|
214 |
+
"<|12.24|>": 50977,
|
215 |
+
"<|12.26|>": 50978,
|
216 |
+
"<|12.28|>": 50979,
|
217 |
+
"<|12.30|>": 50980,
|
218 |
+
"<|12.32|>": 50981,
|
219 |
+
"<|12.34|>": 50982,
|
220 |
+
"<|12.36|>": 50983,
|
221 |
+
"<|12.38|>": 50984,
|
222 |
+
"<|12.40|>": 50985,
|
223 |
+
"<|12.42|>": 50986,
|
224 |
+
"<|12.44|>": 50987,
|
225 |
+
"<|12.46|>": 50988,
|
226 |
+
"<|12.48|>": 50989,
|
227 |
+
"<|12.50|>": 50990,
|
228 |
+
"<|12.52|>": 50991,
|
229 |
+
"<|12.54|>": 50992,
|
230 |
+
"<|12.56|>": 50993,
|
231 |
+
"<|12.58|>": 50994,
|
232 |
+
"<|12.60|>": 50995,
|
233 |
+
"<|12.62|>": 50996,
|
234 |
+
"<|12.64|>": 50997,
|
235 |
+
"<|12.66|>": 50998,
|
236 |
+
"<|12.68|>": 50999,
|
237 |
+
"<|12.70|>": 51000,
|
238 |
+
"<|12.72|>": 51001,
|
239 |
+
"<|12.74|>": 51002,
|
240 |
+
"<|12.76|>": 51003,
|
241 |
+
"<|12.78|>": 51004,
|
242 |
+
"<|12.80|>": 51005,
|
243 |
+
"<|12.82|>": 51006,
|
244 |
+
"<|12.84|>": 51007,
|
245 |
+
"<|12.86|>": 51008,
|
246 |
+
"<|12.88|>": 51009,
|
247 |
+
"<|12.90|>": 51010,
|
248 |
+
"<|12.92|>": 51011,
|
249 |
+
"<|12.94|>": 51012,
|
250 |
+
"<|12.96|>": 51013,
|
251 |
+
"<|12.98|>": 51014,
|
252 |
+
"<|13.00|>": 51015,
|
253 |
+
"<|13.02|>": 51016,
|
254 |
+
"<|13.04|>": 51017,
|
255 |
+
"<|13.06|>": 51018,
|
256 |
+
"<|13.08|>": 51019,
|
257 |
+
"<|13.10|>": 51020,
|
258 |
+
"<|13.12|>": 51021,
|
259 |
+
"<|13.14|>": 51022,
|
260 |
+
"<|13.16|>": 51023,
|
261 |
+
"<|13.18|>": 51024,
|
262 |
+
"<|13.20|>": 51025,
|
263 |
+
"<|13.22|>": 51026,
|
264 |
+
"<|13.24|>": 51027,
|
265 |
+
"<|13.26|>": 51028,
|
266 |
+
"<|13.28|>": 51029,
|
267 |
+
"<|13.30|>": 51030,
|
268 |
+
"<|13.32|>": 51031,
|
269 |
+
"<|13.34|>": 51032,
|
270 |
+
"<|13.36|>": 51033,
|
271 |
+
"<|13.38|>": 51034,
|
272 |
+
"<|13.40|>": 51035,
|
273 |
+
"<|13.42|>": 51036,
|
274 |
+
"<|13.44|>": 51037,
|
275 |
+
"<|13.46|>": 51038,
|
276 |
+
"<|13.48|>": 51039,
|
277 |
+
"<|13.50|>": 51040,
|
278 |
+
"<|13.52|>": 51041,
|
279 |
+
"<|13.54|>": 51042,
|
280 |
+
"<|13.56|>": 51043,
|
281 |
+
"<|13.58|>": 51044,
|
282 |
+
"<|13.60|>": 51045,
|
283 |
+
"<|13.62|>": 51046,
|
284 |
+
"<|13.64|>": 51047,
|
285 |
+
"<|13.66|>": 51048,
|
286 |
+
"<|13.68|>": 51049,
|
287 |
+
"<|13.70|>": 51050,
|
288 |
+
"<|13.72|>": 51051,
|
289 |
+
"<|13.74|>": 51052,
|
290 |
+
"<|13.76|>": 51053,
|
291 |
+
"<|13.78|>": 51054,
|
292 |
+
"<|13.80|>": 51055,
|
293 |
+
"<|13.82|>": 51056,
|
294 |
+
"<|13.84|>": 51057,
|
295 |
+
"<|13.86|>": 51058,
|
296 |
+
"<|13.88|>": 51059,
|
297 |
+
"<|13.90|>": 51060,
|
298 |
+
"<|13.92|>": 51061,
|
299 |
+
"<|13.94|>": 51062,
|
300 |
+
"<|13.96|>": 51063,
|
301 |
+
"<|13.98|>": 51064,
|
302 |
+
"<|14.00|>": 51065,
|
303 |
+
"<|14.02|>": 51066,
|
304 |
+
"<|14.04|>": 51067,
|
305 |
+
"<|14.06|>": 51068,
|
306 |
+
"<|14.08|>": 51069,
|
307 |
+
"<|14.10|>": 51070,
|
308 |
+
"<|14.12|>": 51071,
|
309 |
+
"<|14.14|>": 51072,
|
310 |
+
"<|14.16|>": 51073,
|
311 |
+
"<|14.18|>": 51074,
|
312 |
+
"<|14.20|>": 51075,
|
313 |
+
"<|14.22|>": 51076,
|
314 |
+
"<|14.24|>": 51077,
|
315 |
+
"<|14.26|>": 51078,
|
316 |
+
"<|14.28|>": 51079,
|
317 |
+
"<|14.30|>": 51080,
|
318 |
+
"<|14.32|>": 51081,
|
319 |
+
"<|14.34|>": 51082,
|
320 |
+
"<|14.36|>": 51083,
|
321 |
+
"<|14.38|>": 51084,
|
322 |
+
"<|14.40|>": 51085,
|
323 |
+
"<|14.42|>": 51086,
|
324 |
+
"<|14.44|>": 51087,
|
325 |
+
"<|14.46|>": 51088,
|
326 |
+
"<|14.48|>": 51089,
|
327 |
+
"<|14.50|>": 51090,
|
328 |
+
"<|14.52|>": 51091,
|
329 |
+
"<|14.54|>": 51092,
|
330 |
+
"<|14.56|>": 51093,
|
331 |
+
"<|14.58|>": 51094,
|
332 |
+
"<|14.60|>": 51095,
|
333 |
+
"<|14.62|>": 51096,
|
334 |
+
"<|14.64|>": 51097,
|
335 |
+
"<|14.66|>": 51098,
|
336 |
+
"<|14.68|>": 51099,
|
337 |
+
"<|14.70|>": 51100,
|
338 |
+
"<|14.72|>": 51101,
|
339 |
+
"<|14.74|>": 51102,
|
340 |
+
"<|14.76|>": 51103,
|
341 |
+
"<|14.78|>": 51104,
|
342 |
+
"<|14.80|>": 51105,
|
343 |
+
"<|14.82|>": 51106,
|
344 |
+
"<|14.84|>": 51107,
|
345 |
+
"<|14.86|>": 51108,
|
346 |
+
"<|14.88|>": 51109,
|
347 |
+
"<|14.90|>": 51110,
|
348 |
+
"<|14.92|>": 51111,
|
349 |
+
"<|14.94|>": 51112,
|
350 |
+
"<|14.96|>": 51113,
|
351 |
+
"<|14.98|>": 51114,
|
352 |
+
"<|15.00|>": 51115,
|
353 |
+
"<|15.02|>": 51116,
|
354 |
+
"<|15.04|>": 51117,
|
355 |
+
"<|15.06|>": 51118,
|
356 |
+
"<|15.08|>": 51119,
|
357 |
+
"<|15.10|>": 51120,
|
358 |
+
"<|15.12|>": 51121,
|
359 |
+
"<|15.14|>": 51122,
|
360 |
+
"<|15.16|>": 51123,
|
361 |
+
"<|15.18|>": 51124,
|
362 |
+
"<|15.20|>": 51125,
|
363 |
+
"<|15.22|>": 51126,
|
364 |
+
"<|15.24|>": 51127,
|
365 |
+
"<|15.26|>": 51128,
|
366 |
+
"<|15.28|>": 51129,
|
367 |
+
"<|15.30|>": 51130,
|
368 |
+
"<|15.32|>": 51131,
|
369 |
+
"<|15.34|>": 51132,
|
370 |
+
"<|15.36|>": 51133,
|
371 |
+
"<|15.38|>": 51134,
|
372 |
+
"<|15.40|>": 51135,
|
373 |
+
"<|15.42|>": 51136,
|
374 |
+
"<|15.44|>": 51137,
|
375 |
+
"<|15.46|>": 51138,
|
376 |
+
"<|15.48|>": 51139,
|
377 |
+
"<|15.50|>": 51140,
|
378 |
+
"<|15.52|>": 51141,
|
379 |
+
"<|15.54|>": 51142,
|
380 |
+
"<|15.56|>": 51143,
|
381 |
+
"<|15.58|>": 51144,
|
382 |
+
"<|15.60|>": 51145,
|
383 |
+
"<|15.62|>": 51146,
|
384 |
+
"<|15.64|>": 51147,
|
385 |
+
"<|15.66|>": 51148,
|
386 |
+
"<|15.68|>": 51149,
|
387 |
+
"<|15.70|>": 51150,
|
388 |
+
"<|15.72|>": 51151,
|
389 |
+
"<|15.74|>": 51152,
|
390 |
+
"<|15.76|>": 51153,
|
391 |
+
"<|15.78|>": 51154,
|
392 |
+
"<|15.80|>": 51155,
|
393 |
+
"<|15.82|>": 51156,
|
394 |
+
"<|15.84|>": 51157,
|
395 |
+
"<|15.86|>": 51158,
|
396 |
+
"<|15.88|>": 51159,
|
397 |
+
"<|15.90|>": 51160,
|
398 |
+
"<|15.92|>": 51161,
|
399 |
+
"<|15.94|>": 51162,
|
400 |
+
"<|15.96|>": 51163,
|
401 |
+
"<|15.98|>": 51164,
|
402 |
+
"<|16.00|>": 51165,
|
403 |
+
"<|16.02|>": 51166,
|
404 |
+
"<|16.04|>": 51167,
|
405 |
+
"<|16.06|>": 51168,
|
406 |
+
"<|16.08|>": 51169,
|
407 |
+
"<|16.10|>": 51170,
|
408 |
+
"<|16.12|>": 51171,
|
409 |
+
"<|16.14|>": 51172,
|
410 |
+
"<|16.16|>": 51173,
|
411 |
+
"<|16.18|>": 51174,
|
412 |
+
"<|16.20|>": 51175,
|
413 |
+
"<|16.22|>": 51176,
|
414 |
+
"<|16.24|>": 51177,
|
415 |
+
"<|16.26|>": 51178,
|
416 |
+
"<|16.28|>": 51179,
|
417 |
+
"<|16.30|>": 51180,
|
418 |
+
"<|16.32|>": 51181,
|
419 |
+
"<|16.34|>": 51182,
|
420 |
+
"<|16.36|>": 51183,
|
421 |
+
"<|16.38|>": 51184,
|
422 |
+
"<|16.40|>": 51185,
|
423 |
+
"<|16.42|>": 51186,
|
424 |
+
"<|16.44|>": 51187,
|
425 |
+
"<|16.46|>": 51188,
|
426 |
+
"<|16.48|>": 51189,
|
427 |
+
"<|16.50|>": 51190,
|
428 |
+
"<|16.52|>": 51191,
|
429 |
+
"<|16.54|>": 51192,
|
430 |
+
"<|16.56|>": 51193,
|
431 |
+
"<|16.58|>": 51194,
|
432 |
+
"<|16.60|>": 51195,
|
433 |
+
"<|16.62|>": 51196,
|
434 |
+
"<|16.64|>": 51197,
|
435 |
+
"<|16.66|>": 51198,
|
436 |
+
"<|16.68|>": 51199,
|
437 |
+
"<|16.70|>": 51200,
|
438 |
+
"<|16.72|>": 51201,
|
439 |
+
"<|16.74|>": 51202,
|
440 |
+
"<|16.76|>": 51203,
|
441 |
+
"<|16.78|>": 51204,
|
442 |
+
"<|16.80|>": 51205,
|
443 |
+
"<|16.82|>": 51206,
|
444 |
+
"<|16.84|>": 51207,
|
445 |
+
"<|16.86|>": 51208,
|
446 |
+
"<|16.88|>": 51209,
|
447 |
+
"<|16.90|>": 51210,
|
448 |
+
"<|16.92|>": 51211,
|
449 |
+
"<|16.94|>": 51212,
|
450 |
+
"<|16.96|>": 51213,
|
451 |
+
"<|16.98|>": 51214,
|
452 |
+
"<|17.00|>": 51215,
|
453 |
+
"<|17.02|>": 51216,
|
454 |
+
"<|17.04|>": 51217,
|
455 |
+
"<|17.06|>": 51218,
|
456 |
+
"<|17.08|>": 51219,
|
457 |
+
"<|17.10|>": 51220,
|
458 |
+
"<|17.12|>": 51221,
|
459 |
+
"<|17.14|>": 51222,
|
460 |
+
"<|17.16|>": 51223,
|
461 |
+
"<|17.18|>": 51224,
|
462 |
+
"<|17.20|>": 51225,
|
463 |
+
"<|17.22|>": 51226,
|
464 |
+
"<|17.24|>": 51227,
|
465 |
+
"<|17.26|>": 51228,
|
466 |
+
"<|17.28|>": 51229,
|
467 |
+
"<|17.30|>": 51230,
|
468 |
+
"<|17.32|>": 51231,
|
469 |
+
"<|17.34|>": 51232,
|
470 |
+
"<|17.36|>": 51233,
|
471 |
+
"<|17.38|>": 51234,
|
472 |
+
"<|17.40|>": 51235,
|
473 |
+
"<|17.42|>": 51236,
|
474 |
+
"<|17.44|>": 51237,
|
475 |
+
"<|17.46|>": 51238,
|
476 |
+
"<|17.48|>": 51239,
|
477 |
+
"<|17.50|>": 51240,
|
478 |
+
"<|17.52|>": 51241,
|
479 |
+
"<|17.54|>": 51242,
|
480 |
+
"<|17.56|>": 51243,
|
481 |
+
"<|17.58|>": 51244,
|
482 |
+
"<|17.60|>": 51245,
|
483 |
+
"<|17.62|>": 51246,
|
484 |
+
"<|17.64|>": 51247,
|
485 |
+
"<|17.66|>": 51248,
|
486 |
+
"<|17.68|>": 51249,
|
487 |
+
"<|17.70|>": 51250,
|
488 |
+
"<|17.72|>": 51251,
|
489 |
+
"<|17.74|>": 51252,
|
490 |
+
"<|17.76|>": 51253,
|
491 |
+
"<|17.78|>": 51254,
|
492 |
+
"<|17.80|>": 51255,
|
493 |
+
"<|17.82|>": 51256,
|
494 |
+
"<|17.84|>": 51257,
|
495 |
+
"<|17.86|>": 51258,
|
496 |
+
"<|17.88|>": 51259,
|
497 |
+
"<|17.90|>": 51260,
|
498 |
+
"<|17.92|>": 51261,
|
499 |
+
"<|17.94|>": 51262,
|
500 |
+
"<|17.96|>": 51263,
|
501 |
+
"<|17.98|>": 51264,
|
502 |
+
"<|18.00|>": 51265,
|
503 |
+
"<|18.02|>": 51266,
|
504 |
+
"<|18.04|>": 51267,
|
505 |
+
"<|18.06|>": 51268,
|
506 |
+
"<|18.08|>": 51269,
|
507 |
+
"<|18.10|>": 51270,
|
508 |
+
"<|18.12|>": 51271,
|
509 |
+
"<|18.14|>": 51272,
|
510 |
+
"<|18.16|>": 51273,
|
511 |
+
"<|18.18|>": 51274,
|
512 |
+
"<|18.20|>": 51275,
|
513 |
+
"<|18.22|>": 51276,
|
514 |
+
"<|18.24|>": 51277,
|
515 |
+
"<|18.26|>": 51278,
|
516 |
+
"<|18.28|>": 51279,
|
517 |
+
"<|18.30|>": 51280,
|
518 |
+
"<|18.32|>": 51281,
|
519 |
+
"<|18.34|>": 51282,
|
520 |
+
"<|18.36|>": 51283,
|
521 |
+
"<|18.38|>": 51284,
|
522 |
+
"<|18.40|>": 51285,
|
523 |
+
"<|18.42|>": 51286,
|
524 |
+
"<|18.44|>": 51287,
|
525 |
+
"<|18.46|>": 51288,
|
526 |
+
"<|18.48|>": 51289,
|
527 |
+
"<|18.50|>": 51290,
|
528 |
+
"<|18.52|>": 51291,
|
529 |
+
"<|18.54|>": 51292,
|
530 |
+
"<|18.56|>": 51293,
|
531 |
+
"<|18.58|>": 51294,
|
532 |
+
"<|18.60|>": 51295,
|
533 |
+
"<|18.62|>": 51296,
|
534 |
+
"<|18.64|>": 51297,
|
535 |
+
"<|18.66|>": 51298,
|
536 |
+
"<|18.68|>": 51299,
|
537 |
+
"<|18.70|>": 51300,
|
538 |
+
"<|18.72|>": 51301,
|
539 |
+
"<|18.74|>": 51302,
|
540 |
+
"<|18.76|>": 51303,
|
541 |
+
"<|18.78|>": 51304,
|
542 |
+
"<|18.80|>": 51305,
|
543 |
+
"<|18.82|>": 51306,
|
544 |
+
"<|18.84|>": 51307,
|
545 |
+
"<|18.86|>": 51308,
|
546 |
+
"<|18.88|>": 51309,
|
547 |
+
"<|18.90|>": 51310,
|
548 |
+
"<|18.92|>": 51311,
|
549 |
+
"<|18.94|>": 51312,
|
550 |
+
"<|18.96|>": 51313,
|
551 |
+
"<|18.98|>": 51314,
|
552 |
+
"<|19.00|>": 51315,
|
553 |
+
"<|19.02|>": 51316,
|
554 |
+
"<|19.04|>": 51317,
|
555 |
+
"<|19.06|>": 51318,
|
556 |
+
"<|19.08|>": 51319,
|
557 |
+
"<|19.10|>": 51320,
|
558 |
+
"<|19.12|>": 51321,
|
559 |
+
"<|19.14|>": 51322,
|
560 |
+
"<|19.16|>": 51323,
|
561 |
+
"<|19.18|>": 51324,
|
562 |
+
"<|19.20|>": 51325,
|
563 |
+
"<|19.22|>": 51326,
|
564 |
+
"<|19.24|>": 51327,
|
565 |
+
"<|19.26|>": 51328,
|
566 |
+
"<|19.28|>": 51329,
|
567 |
+
"<|19.30|>": 51330,
|
568 |
+
"<|19.32|>": 51331,
|
569 |
+
"<|19.34|>": 51332,
|
570 |
+
"<|19.36|>": 51333,
|
571 |
+
"<|19.38|>": 51334,
|
572 |
+
"<|19.40|>": 51335,
|
573 |
+
"<|19.42|>": 51336,
|
574 |
+
"<|19.44|>": 51337,
|
575 |
+
"<|19.46|>": 51338,
|
576 |
+
"<|19.48|>": 51339,
|
577 |
+
"<|19.50|>": 51340,
|
578 |
+
"<|19.52|>": 51341,
|
579 |
+
"<|19.54|>": 51342,
|
580 |
+
"<|19.56|>": 51343,
|
581 |
+
"<|19.58|>": 51344,
|
582 |
+
"<|19.60|>": 51345,
|
583 |
+
"<|19.62|>": 51346,
|
584 |
+
"<|19.64|>": 51347,
|
585 |
+
"<|19.66|>": 51348,
|
586 |
+
"<|19.68|>": 51349,
|
587 |
+
"<|19.70|>": 51350,
|
588 |
+
"<|19.72|>": 51351,
|
589 |
+
"<|19.74|>": 51352,
|
590 |
+
"<|19.76|>": 51353,
|
591 |
+
"<|19.78|>": 51354,
|
592 |
+
"<|19.80|>": 51355,
|
593 |
+
"<|19.82|>": 51356,
|
594 |
+
"<|19.84|>": 51357,
|
595 |
+
"<|19.86|>": 51358,
|
596 |
+
"<|19.88|>": 51359,
|
597 |
+
"<|19.90|>": 51360,
|
598 |
+
"<|19.92|>": 51361,
|
599 |
+
"<|19.94|>": 51362,
|
600 |
+
"<|19.96|>": 51363,
|
601 |
+
"<|19.98|>": 51364,
|
602 |
+
"<|2.00|>": 50465,
|
603 |
+
"<|2.02|>": 50466,
|
604 |
+
"<|2.04|>": 50467,
|
605 |
+
"<|2.06|>": 50468,
|
606 |
+
"<|2.08|>": 50469,
|
607 |
+
"<|2.10|>": 50470,
|
608 |
+
"<|2.12|>": 50471,
|
609 |
+
"<|2.14|>": 50472,
|
610 |
+
"<|2.16|>": 50473,
|
611 |
+
"<|2.18|>": 50474,
|
612 |
+
"<|2.20|>": 50475,
|
613 |
+
"<|2.22|>": 50476,
|
614 |
+
"<|2.24|>": 50477,
|
615 |
+
"<|2.26|>": 50478,
|
616 |
+
"<|2.28|>": 50479,
|
617 |
+
"<|2.30|>": 50480,
|
618 |
+
"<|2.32|>": 50481,
|
619 |
+
"<|2.34|>": 50482,
|
620 |
+
"<|2.36|>": 50483,
|
621 |
+
"<|2.38|>": 50484,
|
622 |
+
"<|2.40|>": 50485,
|
623 |
+
"<|2.42|>": 50486,
|
624 |
+
"<|2.44|>": 50487,
|
625 |
+
"<|2.46|>": 50488,
|
626 |
+
"<|2.48|>": 50489,
|
627 |
+
"<|2.50|>": 50490,
|
628 |
+
"<|2.52|>": 50491,
|
629 |
+
"<|2.54|>": 50492,
|
630 |
+
"<|2.56|>": 50493,
|
631 |
+
"<|2.58|>": 50494,
|
632 |
+
"<|2.60|>": 50495,
|
633 |
+
"<|2.62|>": 50496,
|
634 |
+
"<|2.64|>": 50497,
|
635 |
+
"<|2.66|>": 50498,
|
636 |
+
"<|2.68|>": 50499,
|
637 |
+
"<|2.70|>": 50500,
|
638 |
+
"<|2.72|>": 50501,
|
639 |
+
"<|2.74|>": 50502,
|
640 |
+
"<|2.76|>": 50503,
|
641 |
+
"<|2.78|>": 50504,
|
642 |
+
"<|2.80|>": 50505,
|
643 |
+
"<|2.82|>": 50506,
|
644 |
+
"<|2.84|>": 50507,
|
645 |
+
"<|2.86|>": 50508,
|
646 |
+
"<|2.88|>": 50509,
|
647 |
+
"<|2.90|>": 50510,
|
648 |
+
"<|2.92|>": 50511,
|
649 |
+
"<|2.94|>": 50512,
|
650 |
+
"<|2.96|>": 50513,
|
651 |
+
"<|2.98|>": 50514,
|
652 |
+
"<|20.00|>": 51365,
|
653 |
+
"<|20.02|>": 51366,
|
654 |
+
"<|20.04|>": 51367,
|
655 |
+
"<|20.06|>": 51368,
|
656 |
+
"<|20.08|>": 51369,
|
657 |
+
"<|20.10|>": 51370,
|
658 |
+
"<|20.12|>": 51371,
|
659 |
+
"<|20.14|>": 51372,
|
660 |
+
"<|20.16|>": 51373,
|
661 |
+
"<|20.18|>": 51374,
|
662 |
+
"<|20.20|>": 51375,
|
663 |
+
"<|20.22|>": 51376,
|
664 |
+
"<|20.24|>": 51377,
|
665 |
+
"<|20.26|>": 51378,
|
666 |
+
"<|20.28|>": 51379,
|
667 |
+
"<|20.30|>": 51380,
|
668 |
+
"<|20.32|>": 51381,
|
669 |
+
"<|20.34|>": 51382,
|
670 |
+
"<|20.36|>": 51383,
|
671 |
+
"<|20.38|>": 51384,
|
672 |
+
"<|20.40|>": 51385,
|
673 |
+
"<|20.42|>": 51386,
|
674 |
+
"<|20.44|>": 51387,
|
675 |
+
"<|20.46|>": 51388,
|
676 |
+
"<|20.48|>": 51389,
|
677 |
+
"<|20.50|>": 51390,
|
678 |
+
"<|20.52|>": 51391,
|
679 |
+
"<|20.54|>": 51392,
|
680 |
+
"<|20.56|>": 51393,
|
681 |
+
"<|20.58|>": 51394,
|
682 |
+
"<|20.60|>": 51395,
|
683 |
+
"<|20.62|>": 51396,
|
684 |
+
"<|20.64|>": 51397,
|
685 |
+
"<|20.66|>": 51398,
|
686 |
+
"<|20.68|>": 51399,
|
687 |
+
"<|20.70|>": 51400,
|
688 |
+
"<|20.72|>": 51401,
|
689 |
+
"<|20.74|>": 51402,
|
690 |
+
"<|20.76|>": 51403,
|
691 |
+
"<|20.78|>": 51404,
|
692 |
+
"<|20.80|>": 51405,
|
693 |
+
"<|20.82|>": 51406,
|
694 |
+
"<|20.84|>": 51407,
|
695 |
+
"<|20.86|>": 51408,
|
696 |
+
"<|20.88|>": 51409,
|
697 |
+
"<|20.90|>": 51410,
|
698 |
+
"<|20.92|>": 51411,
|
699 |
+
"<|20.94|>": 51412,
|
700 |
+
"<|20.96|>": 51413,
|
701 |
+
"<|20.98|>": 51414,
|
702 |
+
"<|21.00|>": 51415,
|
703 |
+
"<|21.02|>": 51416,
|
704 |
+
"<|21.04|>": 51417,
|
705 |
+
"<|21.06|>": 51418,
|
706 |
+
"<|21.08|>": 51419,
|
707 |
+
"<|21.10|>": 51420,
|
708 |
+
"<|21.12|>": 51421,
|
709 |
+
"<|21.14|>": 51422,
|
710 |
+
"<|21.16|>": 51423,
|
711 |
+
"<|21.18|>": 51424,
|
712 |
+
"<|21.20|>": 51425,
|
713 |
+
"<|21.22|>": 51426,
|
714 |
+
"<|21.24|>": 51427,
|
715 |
+
"<|21.26|>": 51428,
|
716 |
+
"<|21.28|>": 51429,
|
717 |
+
"<|21.30|>": 51430,
|
718 |
+
"<|21.32|>": 51431,
|
719 |
+
"<|21.34|>": 51432,
|
720 |
+
"<|21.36|>": 51433,
|
721 |
+
"<|21.38|>": 51434,
|
722 |
+
"<|21.40|>": 51435,
|
723 |
+
"<|21.42|>": 51436,
|
724 |
+
"<|21.44|>": 51437,
|
725 |
+
"<|21.46|>": 51438,
|
726 |
+
"<|21.48|>": 51439,
|
727 |
+
"<|21.50|>": 51440,
|
728 |
+
"<|21.52|>": 51441,
|
729 |
+
"<|21.54|>": 51442,
|
730 |
+
"<|21.56|>": 51443,
|
731 |
+
"<|21.58|>": 51444,
|
732 |
+
"<|21.60|>": 51445,
|
733 |
+
"<|21.62|>": 51446,
|
734 |
+
"<|21.64|>": 51447,
|
735 |
+
"<|21.66|>": 51448,
|
736 |
+
"<|21.68|>": 51449,
|
737 |
+
"<|21.70|>": 51450,
|
738 |
+
"<|21.72|>": 51451,
|
739 |
+
"<|21.74|>": 51452,
|
740 |
+
"<|21.76|>": 51453,
|
741 |
+
"<|21.78|>": 51454,
|
742 |
+
"<|21.80|>": 51455,
|
743 |
+
"<|21.82|>": 51456,
|
744 |
+
"<|21.84|>": 51457,
|
745 |
+
"<|21.86|>": 51458,
|
746 |
+
"<|21.88|>": 51459,
|
747 |
+
"<|21.90|>": 51460,
|
748 |
+
"<|21.92|>": 51461,
|
749 |
+
"<|21.94|>": 51462,
|
750 |
+
"<|21.96|>": 51463,
|
751 |
+
"<|21.98|>": 51464,
|
752 |
+
"<|22.00|>": 51465,
|
753 |
+
"<|22.02|>": 51466,
|
754 |
+
"<|22.04|>": 51467,
|
755 |
+
"<|22.06|>": 51468,
|
756 |
+
"<|22.08|>": 51469,
|
757 |
+
"<|22.10|>": 51470,
|
758 |
+
"<|22.12|>": 51471,
|
759 |
+
"<|22.14|>": 51472,
|
760 |
+
"<|22.16|>": 51473,
|
761 |
+
"<|22.18|>": 51474,
|
762 |
+
"<|22.20|>": 51475,
|
763 |
+
"<|22.22|>": 51476,
|
764 |
+
"<|22.24|>": 51477,
|
765 |
+
"<|22.26|>": 51478,
|
766 |
+
"<|22.28|>": 51479,
|
767 |
+
"<|22.30|>": 51480,
|
768 |
+
"<|22.32|>": 51481,
|
769 |
+
"<|22.34|>": 51482,
|
770 |
+
"<|22.36|>": 51483,
|
771 |
+
"<|22.38|>": 51484,
|
772 |
+
"<|22.40|>": 51485,
|
773 |
+
"<|22.42|>": 51486,
|
774 |
+
"<|22.44|>": 51487,
|
775 |
+
"<|22.46|>": 51488,
|
776 |
+
"<|22.48|>": 51489,
|
777 |
+
"<|22.50|>": 51490,
|
778 |
+
"<|22.52|>": 51491,
|
779 |
+
"<|22.54|>": 51492,
|
780 |
+
"<|22.56|>": 51493,
|
781 |
+
"<|22.58|>": 51494,
|
782 |
+
"<|22.60|>": 51495,
|
783 |
+
"<|22.62|>": 51496,
|
784 |
+
"<|22.64|>": 51497,
|
785 |
+
"<|22.66|>": 51498,
|
786 |
+
"<|22.68|>": 51499,
|
787 |
+
"<|22.70|>": 51500,
|
788 |
+
"<|22.72|>": 51501,
|
789 |
+
"<|22.74|>": 51502,
|
790 |
+
"<|22.76|>": 51503,
|
791 |
+
"<|22.78|>": 51504,
|
792 |
+
"<|22.80|>": 51505,
|
793 |
+
"<|22.82|>": 51506,
|
794 |
+
"<|22.84|>": 51507,
|
795 |
+
"<|22.86|>": 51508,
|
796 |
+
"<|22.88|>": 51509,
|
797 |
+
"<|22.90|>": 51510,
|
798 |
+
"<|22.92|>": 51511,
|
799 |
+
"<|22.94|>": 51512,
|
800 |
+
"<|22.96|>": 51513,
|
801 |
+
"<|22.98|>": 51514,
|
802 |
+
"<|23.00|>": 51515,
|
803 |
+
"<|23.02|>": 51516,
|
804 |
+
"<|23.04|>": 51517,
|
805 |
+
"<|23.06|>": 51518,
|
806 |
+
"<|23.08|>": 51519,
|
807 |
+
"<|23.10|>": 51520,
|
808 |
+
"<|23.12|>": 51521,
|
809 |
+
"<|23.14|>": 51522,
|
810 |
+
"<|23.16|>": 51523,
|
811 |
+
"<|23.18|>": 51524,
|
812 |
+
"<|23.20|>": 51525,
|
813 |
+
"<|23.22|>": 51526,
|
814 |
+
"<|23.24|>": 51527,
|
815 |
+
"<|23.26|>": 51528,
|
816 |
+
"<|23.28|>": 51529,
|
817 |
+
"<|23.30|>": 51530,
|
818 |
+
"<|23.32|>": 51531,
|
819 |
+
"<|23.34|>": 51532,
|
820 |
+
"<|23.36|>": 51533,
|
821 |
+
"<|23.38|>": 51534,
|
822 |
+
"<|23.40|>": 51535,
|
823 |
+
"<|23.42|>": 51536,
|
824 |
+
"<|23.44|>": 51537,
|
825 |
+
"<|23.46|>": 51538,
|
826 |
+
"<|23.48|>": 51539,
|
827 |
+
"<|23.50|>": 51540,
|
828 |
+
"<|23.52|>": 51541,
|
829 |
+
"<|23.54|>": 51542,
|
830 |
+
"<|23.56|>": 51543,
|
831 |
+
"<|23.58|>": 51544,
|
832 |
+
"<|23.60|>": 51545,
|
833 |
+
"<|23.62|>": 51546,
|
834 |
+
"<|23.64|>": 51547,
|
835 |
+
"<|23.66|>": 51548,
|
836 |
+
"<|23.68|>": 51549,
|
837 |
+
"<|23.70|>": 51550,
|
838 |
+
"<|23.72|>": 51551,
|
839 |
+
"<|23.74|>": 51552,
|
840 |
+
"<|23.76|>": 51553,
|
841 |
+
"<|23.78|>": 51554,
|
842 |
+
"<|23.80|>": 51555,
|
843 |
+
"<|23.82|>": 51556,
|
844 |
+
"<|23.84|>": 51557,
|
845 |
+
"<|23.86|>": 51558,
|
846 |
+
"<|23.88|>": 51559,
|
847 |
+
"<|23.90|>": 51560,
|
848 |
+
"<|23.92|>": 51561,
|
849 |
+
"<|23.94|>": 51562,
|
850 |
+
"<|23.96|>": 51563,
|
851 |
+
"<|23.98|>": 51564,
|
852 |
+
"<|24.00|>": 51565,
|
853 |
+
"<|24.02|>": 51566,
|
854 |
+
"<|24.04|>": 51567,
|
855 |
+
"<|24.06|>": 51568,
|
856 |
+
"<|24.08|>": 51569,
|
857 |
+
"<|24.10|>": 51570,
|
858 |
+
"<|24.12|>": 51571,
|
859 |
+
"<|24.14|>": 51572,
|
860 |
+
"<|24.16|>": 51573,
|
861 |
+
"<|24.18|>": 51574,
|
862 |
+
"<|24.20|>": 51575,
|
863 |
+
"<|24.22|>": 51576,
|
864 |
+
"<|24.24|>": 51577,
|
865 |
+
"<|24.26|>": 51578,
|
866 |
+
"<|24.28|>": 51579,
|
867 |
+
"<|24.30|>": 51580,
|
868 |
+
"<|24.32|>": 51581,
|
869 |
+
"<|24.34|>": 51582,
|
870 |
+
"<|24.36|>": 51583,
|
871 |
+
"<|24.38|>": 51584,
|
872 |
+
"<|24.40|>": 51585,
|
873 |
+
"<|24.42|>": 51586,
|
874 |
+
"<|24.44|>": 51587,
|
875 |
+
"<|24.46|>": 51588,
|
876 |
+
"<|24.48|>": 51589,
|
877 |
+
"<|24.50|>": 51590,
|
878 |
+
"<|24.52|>": 51591,
|
879 |
+
"<|24.54|>": 51592,
|
880 |
+
"<|24.56|>": 51593,
|
881 |
+
"<|24.58|>": 51594,
|
882 |
+
"<|24.60|>": 51595,
|
883 |
+
"<|24.62|>": 51596,
|
884 |
+
"<|24.64|>": 51597,
|
885 |
+
"<|24.66|>": 51598,
|
886 |
+
"<|24.68|>": 51599,
|
887 |
+
"<|24.70|>": 51600,
|
888 |
+
"<|24.72|>": 51601,
|
889 |
+
"<|24.74|>": 51602,
|
890 |
+
"<|24.76|>": 51603,
|
891 |
+
"<|24.78|>": 51604,
|
892 |
+
"<|24.80|>": 51605,
|
893 |
+
"<|24.82|>": 51606,
|
894 |
+
"<|24.84|>": 51607,
|
895 |
+
"<|24.86|>": 51608,
|
896 |
+
"<|24.88|>": 51609,
|
897 |
+
"<|24.90|>": 51610,
|
898 |
+
"<|24.92|>": 51611,
|
899 |
+
"<|24.94|>": 51612,
|
900 |
+
"<|24.96|>": 51613,
|
901 |
+
"<|24.98|>": 51614,
|
902 |
+
"<|25.00|>": 51615,
|
903 |
+
"<|25.02|>": 51616,
|
904 |
+
"<|25.04|>": 51617,
|
905 |
+
"<|25.06|>": 51618,
|
906 |
+
"<|25.08|>": 51619,
|
907 |
+
"<|25.10|>": 51620,
|
908 |
+
"<|25.12|>": 51621,
|
909 |
+
"<|25.14|>": 51622,
|
910 |
+
"<|25.16|>": 51623,
|
911 |
+
"<|25.18|>": 51624,
|
912 |
+
"<|25.20|>": 51625,
|
913 |
+
"<|25.22|>": 51626,
|
914 |
+
"<|25.24|>": 51627,
|
915 |
+
"<|25.26|>": 51628,
|
916 |
+
"<|25.28|>": 51629,
|
917 |
+
"<|25.30|>": 51630,
|
918 |
+
"<|25.32|>": 51631,
|
919 |
+
"<|25.34|>": 51632,
|
920 |
+
"<|25.36|>": 51633,
|
921 |
+
"<|25.38|>": 51634,
|
922 |
+
"<|25.40|>": 51635,
|
923 |
+
"<|25.42|>": 51636,
|
924 |
+
"<|25.44|>": 51637,
|
925 |
+
"<|25.46|>": 51638,
|
926 |
+
"<|25.48|>": 51639,
|
927 |
+
"<|25.50|>": 51640,
|
928 |
+
"<|25.52|>": 51641,
|
929 |
+
"<|25.54|>": 51642,
|
930 |
+
"<|25.56|>": 51643,
|
931 |
+
"<|25.58|>": 51644,
|
932 |
+
"<|25.60|>": 51645,
|
933 |
+
"<|25.62|>": 51646,
|
934 |
+
"<|25.64|>": 51647,
|
935 |
+
"<|25.66|>": 51648,
|
936 |
+
"<|25.68|>": 51649,
|
937 |
+
"<|25.70|>": 51650,
|
938 |
+
"<|25.72|>": 51651,
|
939 |
+
"<|25.74|>": 51652,
|
940 |
+
"<|25.76|>": 51653,
|
941 |
+
"<|25.78|>": 51654,
|
942 |
+
"<|25.80|>": 51655,
|
943 |
+
"<|25.82|>": 51656,
|
944 |
+
"<|25.84|>": 51657,
|
945 |
+
"<|25.86|>": 51658,
|
946 |
+
"<|25.88|>": 51659,
|
947 |
+
"<|25.90|>": 51660,
|
948 |
+
"<|25.92|>": 51661,
|
949 |
+
"<|25.94|>": 51662,
|
950 |
+
"<|25.96|>": 51663,
|
951 |
+
"<|25.98|>": 51664,
|
952 |
+
"<|26.00|>": 51665,
|
953 |
+
"<|26.02|>": 51666,
|
954 |
+
"<|26.04|>": 51667,
|
955 |
+
"<|26.06|>": 51668,
|
956 |
+
"<|26.08|>": 51669,
|
957 |
+
"<|26.10|>": 51670,
|
958 |
+
"<|26.12|>": 51671,
|
959 |
+
"<|26.14|>": 51672,
|
960 |
+
"<|26.16|>": 51673,
|
961 |
+
"<|26.18|>": 51674,
|
962 |
+
"<|26.20|>": 51675,
|
963 |
+
"<|26.22|>": 51676,
|
964 |
+
"<|26.24|>": 51677,
|
965 |
+
"<|26.26|>": 51678,
|
966 |
+
"<|26.28|>": 51679,
|
967 |
+
"<|26.30|>": 51680,
|
968 |
+
"<|26.32|>": 51681,
|
969 |
+
"<|26.34|>": 51682,
|
970 |
+
"<|26.36|>": 51683,
|
971 |
+
"<|26.38|>": 51684,
|
972 |
+
"<|26.40|>": 51685,
|
973 |
+
"<|26.42|>": 51686,
|
974 |
+
"<|26.44|>": 51687,
|
975 |
+
"<|26.46|>": 51688,
|
976 |
+
"<|26.48|>": 51689,
|
977 |
+
"<|26.50|>": 51690,
|
978 |
+
"<|26.52|>": 51691,
|
979 |
+
"<|26.54|>": 51692,
|
980 |
+
"<|26.56|>": 51693,
|
981 |
+
"<|26.58|>": 51694,
|
982 |
+
"<|26.60|>": 51695,
|
983 |
+
"<|26.62|>": 51696,
|
984 |
+
"<|26.64|>": 51697,
|
985 |
+
"<|26.66|>": 51698,
|
986 |
+
"<|26.68|>": 51699,
|
987 |
+
"<|26.70|>": 51700,
|
988 |
+
"<|26.72|>": 51701,
|
989 |
+
"<|26.74|>": 51702,
|
990 |
+
"<|26.76|>": 51703,
|
991 |
+
"<|26.78|>": 51704,
|
992 |
+
"<|26.80|>": 51705,
|
993 |
+
"<|26.82|>": 51706,
|
994 |
+
"<|26.84|>": 51707,
|
995 |
+
"<|26.86|>": 51708,
|
996 |
+
"<|26.88|>": 51709,
|
997 |
+
"<|26.90|>": 51710,
|
998 |
+
"<|26.92|>": 51711,
|
999 |
+
"<|26.94|>": 51712,
|
1000 |
+
"<|26.96|>": 51713,
|
1001 |
+
"<|26.98|>": 51714,
|
1002 |
+
"<|27.00|>": 51715,
|
1003 |
+
"<|27.02|>": 51716,
|
1004 |
+
"<|27.04|>": 51717,
|
1005 |
+
"<|27.06|>": 51718,
|
1006 |
+
"<|27.08|>": 51719,
|
1007 |
+
"<|27.10|>": 51720,
|
1008 |
+
"<|27.12|>": 51721,
|
1009 |
+
"<|27.14|>": 51722,
|
1010 |
+
"<|27.16|>": 51723,
|
1011 |
+
"<|27.18|>": 51724,
|
1012 |
+
"<|27.20|>": 51725,
|
1013 |
+
"<|27.22|>": 51726,
|
1014 |
+
"<|27.24|>": 51727,
|
1015 |
+
"<|27.26|>": 51728,
|
1016 |
+
"<|27.28|>": 51729,
|
1017 |
+
"<|27.30|>": 51730,
|
1018 |
+
"<|27.32|>": 51731,
|
1019 |
+
"<|27.34|>": 51732,
|
1020 |
+
"<|27.36|>": 51733,
|
1021 |
+
"<|27.38|>": 51734,
|
1022 |
+
"<|27.40|>": 51735,
|
1023 |
+
"<|27.42|>": 51736,
|
1024 |
+
"<|27.44|>": 51737,
|
1025 |
+
"<|27.46|>": 51738,
|
1026 |
+
"<|27.48|>": 51739,
|
1027 |
+
"<|27.50|>": 51740,
|
1028 |
+
"<|27.52|>": 51741,
|
1029 |
+
"<|27.54|>": 51742,
|
1030 |
+
"<|27.56|>": 51743,
|
1031 |
+
"<|27.58|>": 51744,
|
1032 |
+
"<|27.60|>": 51745,
|
1033 |
+
"<|27.62|>": 51746,
|
1034 |
+
"<|27.64|>": 51747,
|
1035 |
+
"<|27.66|>": 51748,
|
1036 |
+
"<|27.68|>": 51749,
|
1037 |
+
"<|27.70|>": 51750,
|
1038 |
+
"<|27.72|>": 51751,
|
1039 |
+
"<|27.74|>": 51752,
|
1040 |
+
"<|27.76|>": 51753,
|
1041 |
+
"<|27.78|>": 51754,
|
1042 |
+
"<|27.80|>": 51755,
|
1043 |
+
"<|27.82|>": 51756,
|
1044 |
+
"<|27.84|>": 51757,
|
1045 |
+
"<|27.86|>": 51758,
|
1046 |
+
"<|27.88|>": 51759,
|
1047 |
+
"<|27.90|>": 51760,
|
1048 |
+
"<|27.92|>": 51761,
|
1049 |
+
"<|27.94|>": 51762,
|
1050 |
+
"<|27.96|>": 51763,
|
1051 |
+
"<|27.98|>": 51764,
|
1052 |
+
"<|28.00|>": 51765,
|
1053 |
+
"<|28.02|>": 51766,
|
1054 |
+
"<|28.04|>": 51767,
|
1055 |
+
"<|28.06|>": 51768,
|
1056 |
+
"<|28.08|>": 51769,
|
1057 |
+
"<|28.10|>": 51770,
|
1058 |
+
"<|28.12|>": 51771,
|
1059 |
+
"<|28.14|>": 51772,
|
1060 |
+
"<|28.16|>": 51773,
|
1061 |
+
"<|28.18|>": 51774,
|
1062 |
+
"<|28.20|>": 51775,
|
1063 |
+
"<|28.22|>": 51776,
|
1064 |
+
"<|28.24|>": 51777,
|
1065 |
+
"<|28.26|>": 51778,
|
1066 |
+
"<|28.28|>": 51779,
|
1067 |
+
"<|28.30|>": 51780,
|
1068 |
+
"<|28.32|>": 51781,
|
1069 |
+
"<|28.34|>": 51782,
|
1070 |
+
"<|28.36|>": 51783,
|
1071 |
+
"<|28.38|>": 51784,
|
1072 |
+
"<|28.40|>": 51785,
|
1073 |
+
"<|28.42|>": 51786,
|
1074 |
+
"<|28.44|>": 51787,
|
1075 |
+
"<|28.46|>": 51788,
|
1076 |
+
"<|28.48|>": 51789,
|
1077 |
+
"<|28.50|>": 51790,
|
1078 |
+
"<|28.52|>": 51791,
|
1079 |
+
"<|28.54|>": 51792,
|
1080 |
+
"<|28.56|>": 51793,
|
1081 |
+
"<|28.58|>": 51794,
|
1082 |
+
"<|28.60|>": 51795,
|
1083 |
+
"<|28.62|>": 51796,
|
1084 |
+
"<|28.64|>": 51797,
|
1085 |
+
"<|28.66|>": 51798,
|
1086 |
+
"<|28.68|>": 51799,
|
1087 |
+
"<|28.70|>": 51800,
|
1088 |
+
"<|28.72|>": 51801,
|
1089 |
+
"<|28.74|>": 51802,
|
1090 |
+
"<|28.76|>": 51803,
|
1091 |
+
"<|28.78|>": 51804,
|
1092 |
+
"<|28.80|>": 51805,
|
1093 |
+
"<|28.82|>": 51806,
|
1094 |
+
"<|28.84|>": 51807,
|
1095 |
+
"<|28.86|>": 51808,
|
1096 |
+
"<|28.88|>": 51809,
|
1097 |
+
"<|28.90|>": 51810,
|
1098 |
+
"<|28.92|>": 51811,
|
1099 |
+
"<|28.94|>": 51812,
|
1100 |
+
"<|28.96|>": 51813,
|
1101 |
+
"<|28.98|>": 51814,
|
1102 |
+
"<|29.00|>": 51815,
|
1103 |
+
"<|29.02|>": 51816,
|
1104 |
+
"<|29.04|>": 51817,
|
1105 |
+
"<|29.06|>": 51818,
|
1106 |
+
"<|29.08|>": 51819,
|
1107 |
+
"<|29.10|>": 51820,
|
1108 |
+
"<|29.12|>": 51821,
|
1109 |
+
"<|29.14|>": 51822,
|
1110 |
+
"<|29.16|>": 51823,
|
1111 |
+
"<|29.18|>": 51824,
|
1112 |
+
"<|29.20|>": 51825,
|
1113 |
+
"<|29.22|>": 51826,
|
1114 |
+
"<|29.24|>": 51827,
|
1115 |
+
"<|29.26|>": 51828,
|
1116 |
+
"<|29.28|>": 51829,
|
1117 |
+
"<|29.30|>": 51830,
|
1118 |
+
"<|29.32|>": 51831,
|
1119 |
+
"<|29.34|>": 51832,
|
1120 |
+
"<|29.36|>": 51833,
|
1121 |
+
"<|29.38|>": 51834,
|
1122 |
+
"<|29.40|>": 51835,
|
1123 |
+
"<|29.42|>": 51836,
|
1124 |
+
"<|29.44|>": 51837,
|
1125 |
+
"<|29.46|>": 51838,
|
1126 |
+
"<|29.48|>": 51839,
|
1127 |
+
"<|29.50|>": 51840,
|
1128 |
+
"<|29.52|>": 51841,
|
1129 |
+
"<|29.54|>": 51842,
|
1130 |
+
"<|29.56|>": 51843,
|
1131 |
+
"<|29.58|>": 51844,
|
1132 |
+
"<|29.60|>": 51845,
|
1133 |
+
"<|29.62|>": 51846,
|
1134 |
+
"<|29.64|>": 51847,
|
1135 |
+
"<|29.66|>": 51848,
|
1136 |
+
"<|29.68|>": 51849,
|
1137 |
+
"<|29.70|>": 51850,
|
1138 |
+
"<|29.72|>": 51851,
|
1139 |
+
"<|29.74|>": 51852,
|
1140 |
+
"<|29.76|>": 51853,
|
1141 |
+
"<|29.78|>": 51854,
|
1142 |
+
"<|29.80|>": 51855,
|
1143 |
+
"<|29.82|>": 51856,
|
1144 |
+
"<|29.84|>": 51857,
|
1145 |
+
"<|29.86|>": 51858,
|
1146 |
+
"<|29.88|>": 51859,
|
1147 |
+
"<|29.90|>": 51860,
|
1148 |
+
"<|29.92|>": 51861,
|
1149 |
+
"<|29.94|>": 51862,
|
1150 |
+
"<|29.96|>": 51863,
|
1151 |
+
"<|29.98|>": 51864,
|
1152 |
+
"<|3.00|>": 50515,
|
1153 |
+
"<|3.02|>": 50516,
|
1154 |
+
"<|3.04|>": 50517,
|
1155 |
+
"<|3.06|>": 50518,
|
1156 |
+
"<|3.08|>": 50519,
|
1157 |
+
"<|3.10|>": 50520,
|
1158 |
+
"<|3.12|>": 50521,
|
1159 |
+
"<|3.14|>": 50522,
|
1160 |
+
"<|3.16|>": 50523,
|
1161 |
+
"<|3.18|>": 50524,
|
1162 |
+
"<|3.20|>": 50525,
|
1163 |
+
"<|3.22|>": 50526,
|
1164 |
+
"<|3.24|>": 50527,
|
1165 |
+
"<|3.26|>": 50528,
|
1166 |
+
"<|3.28|>": 50529,
|
1167 |
+
"<|3.30|>": 50530,
|
1168 |
+
"<|3.32|>": 50531,
|
1169 |
+
"<|3.34|>": 50532,
|
1170 |
+
"<|3.36|>": 50533,
|
1171 |
+
"<|3.38|>": 50534,
|
1172 |
+
"<|3.40|>": 50535,
|
1173 |
+
"<|3.42|>": 50536,
|
1174 |
+
"<|3.44|>": 50537,
|
1175 |
+
"<|3.46|>": 50538,
|
1176 |
+
"<|3.48|>": 50539,
|
1177 |
+
"<|3.50|>": 50540,
|
1178 |
+
"<|3.52|>": 50541,
|
1179 |
+
"<|3.54|>": 50542,
|
1180 |
+
"<|3.56|>": 50543,
|
1181 |
+
"<|3.58|>": 50544,
|
1182 |
+
"<|3.60|>": 50545,
|
1183 |
+
"<|3.62|>": 50546,
|
1184 |
+
"<|3.64|>": 50547,
|
1185 |
+
"<|3.66|>": 50548,
|
1186 |
+
"<|3.68|>": 50549,
|
1187 |
+
"<|3.70|>": 50550,
|
1188 |
+
"<|3.72|>": 50551,
|
1189 |
+
"<|3.74|>": 50552,
|
1190 |
+
"<|3.76|>": 50553,
|
1191 |
+
"<|3.78|>": 50554,
|
1192 |
+
"<|3.80|>": 50555,
|
1193 |
+
"<|3.82|>": 50556,
|
1194 |
+
"<|3.84|>": 50557,
|
1195 |
+
"<|3.86|>": 50558,
|
1196 |
+
"<|3.88|>": 50559,
|
1197 |
+
"<|3.90|>": 50560,
|
1198 |
+
"<|3.92|>": 50561,
|
1199 |
+
"<|3.94|>": 50562,
|
1200 |
+
"<|3.96|>": 50563,
|
1201 |
+
"<|3.98|>": 50564,
|
1202 |
+
"<|30.00|>": 51865,
|
1203 |
+
"<|4.00|>": 50565,
|
1204 |
+
"<|4.02|>": 50566,
|
1205 |
+
"<|4.04|>": 50567,
|
1206 |
+
"<|4.06|>": 50568,
|
1207 |
+
"<|4.08|>": 50569,
|
1208 |
+
"<|4.10|>": 50570,
|
1209 |
+
"<|4.12|>": 50571,
|
1210 |
+
"<|4.14|>": 50572,
|
1211 |
+
"<|4.16|>": 50573,
|
1212 |
+
"<|4.18|>": 50574,
|
1213 |
+
"<|4.20|>": 50575,
|
1214 |
+
"<|4.22|>": 50576,
|
1215 |
+
"<|4.24|>": 50577,
|
1216 |
+
"<|4.26|>": 50578,
|
1217 |
+
"<|4.28|>": 50579,
|
1218 |
+
"<|4.30|>": 50580,
|
1219 |
+
"<|4.32|>": 50581,
|
1220 |
+
"<|4.34|>": 50582,
|
1221 |
+
"<|4.36|>": 50583,
|
1222 |
+
"<|4.38|>": 50584,
|
1223 |
+
"<|4.40|>": 50585,
|
1224 |
+
"<|4.42|>": 50586,
|
1225 |
+
"<|4.44|>": 50587,
|
1226 |
+
"<|4.46|>": 50588,
|
1227 |
+
"<|4.48|>": 50589,
|
1228 |
+
"<|4.50|>": 50590,
|
1229 |
+
"<|4.52|>": 50591,
|
1230 |
+
"<|4.54|>": 50592,
|
1231 |
+
"<|4.56|>": 50593,
|
1232 |
+
"<|4.58|>": 50594,
|
1233 |
+
"<|4.60|>": 50595,
|
1234 |
+
"<|4.62|>": 50596,
|
1235 |
+
"<|4.64|>": 50597,
|
1236 |
+
"<|4.66|>": 50598,
|
1237 |
+
"<|4.68|>": 50599,
|
1238 |
+
"<|4.70|>": 50600,
|
1239 |
+
"<|4.72|>": 50601,
|
1240 |
+
"<|4.74|>": 50602,
|
1241 |
+
"<|4.76|>": 50603,
|
1242 |
+
"<|4.78|>": 50604,
|
1243 |
+
"<|4.80|>": 50605,
|
1244 |
+
"<|4.82|>": 50606,
|
1245 |
+
"<|4.84|>": 50607,
|
1246 |
+
"<|4.86|>": 50608,
|
1247 |
+
"<|4.88|>": 50609,
|
1248 |
+
"<|4.90|>": 50610,
|
1249 |
+
"<|4.92|>": 50611,
|
1250 |
+
"<|4.94|>": 50612,
|
1251 |
+
"<|4.96|>": 50613,
|
1252 |
+
"<|4.98|>": 50614,
|
1253 |
+
"<|5.00|>": 50615,
|
1254 |
+
"<|5.02|>": 50616,
|
1255 |
+
"<|5.04|>": 50617,
|
1256 |
+
"<|5.06|>": 50618,
|
1257 |
+
"<|5.08|>": 50619,
|
1258 |
+
"<|5.10|>": 50620,
|
1259 |
+
"<|5.12|>": 50621,
|
1260 |
+
"<|5.14|>": 50622,
|
1261 |
+
"<|5.16|>": 50623,
|
1262 |
+
"<|5.18|>": 50624,
|
1263 |
+
"<|5.20|>": 50625,
|
1264 |
+
"<|5.22|>": 50626,
|
1265 |
+
"<|5.24|>": 50627,
|
1266 |
+
"<|5.26|>": 50628,
|
1267 |
+
"<|5.28|>": 50629,
|
1268 |
+
"<|5.30|>": 50630,
|
1269 |
+
"<|5.32|>": 50631,
|
1270 |
+
"<|5.34|>": 50632,
|
1271 |
+
"<|5.36|>": 50633,
|
1272 |
+
"<|5.38|>": 50634,
|
1273 |
+
"<|5.40|>": 50635,
|
1274 |
+
"<|5.42|>": 50636,
|
1275 |
+
"<|5.44|>": 50637,
|
1276 |
+
"<|5.46|>": 50638,
|
1277 |
+
"<|5.48|>": 50639,
|
1278 |
+
"<|5.50|>": 50640,
|
1279 |
+
"<|5.52|>": 50641,
|
1280 |
+
"<|5.54|>": 50642,
|
1281 |
+
"<|5.56|>": 50643,
|
1282 |
+
"<|5.58|>": 50644,
|
1283 |
+
"<|5.60|>": 50645,
|
1284 |
+
"<|5.62|>": 50646,
|
1285 |
+
"<|5.64|>": 50647,
|
1286 |
+
"<|5.66|>": 50648,
|
1287 |
+
"<|5.68|>": 50649,
|
1288 |
+
"<|5.70|>": 50650,
|
1289 |
+
"<|5.72|>": 50651,
|
1290 |
+
"<|5.74|>": 50652,
|
1291 |
+
"<|5.76|>": 50653,
|
1292 |
+
"<|5.78|>": 50654,
|
1293 |
+
"<|5.80|>": 50655,
|
1294 |
+
"<|5.82|>": 50656,
|
1295 |
+
"<|5.84|>": 50657,
|
1296 |
+
"<|5.86|>": 50658,
|
1297 |
+
"<|5.88|>": 50659,
|
1298 |
+
"<|5.90|>": 50660,
|
1299 |
+
"<|5.92|>": 50661,
|
1300 |
+
"<|5.94|>": 50662,
|
1301 |
+
"<|5.96|>": 50663,
|
1302 |
+
"<|5.98|>": 50664,
|
1303 |
+
"<|6.00|>": 50665,
|
1304 |
+
"<|6.02|>": 50666,
|
1305 |
+
"<|6.04|>": 50667,
|
1306 |
+
"<|6.06|>": 50668,
|
1307 |
+
"<|6.08|>": 50669,
|
1308 |
+
"<|6.10|>": 50670,
|
1309 |
+
"<|6.12|>": 50671,
|
1310 |
+
"<|6.14|>": 50672,
|
1311 |
+
"<|6.16|>": 50673,
|
1312 |
+
"<|6.18|>": 50674,
|
1313 |
+
"<|6.20|>": 50675,
|
1314 |
+
"<|6.22|>": 50676,
|
1315 |
+
"<|6.24|>": 50677,
|
1316 |
+
"<|6.26|>": 50678,
|
1317 |
+
"<|6.28|>": 50679,
|
1318 |
+
"<|6.30|>": 50680,
|
1319 |
+
"<|6.32|>": 50681,
|
1320 |
+
"<|6.34|>": 50682,
|
1321 |
+
"<|6.36|>": 50683,
|
1322 |
+
"<|6.38|>": 50684,
|
1323 |
+
"<|6.40|>": 50685,
|
1324 |
+
"<|6.42|>": 50686,
|
1325 |
+
"<|6.44|>": 50687,
|
1326 |
+
"<|6.46|>": 50688,
|
1327 |
+
"<|6.48|>": 50689,
|
1328 |
+
"<|6.50|>": 50690,
|
1329 |
+
"<|6.52|>": 50691,
|
1330 |
+
"<|6.54|>": 50692,
|
1331 |
+
"<|6.56|>": 50693,
|
1332 |
+
"<|6.58|>": 50694,
|
1333 |
+
"<|6.60|>": 50695,
|
1334 |
+
"<|6.62|>": 50696,
|
1335 |
+
"<|6.64|>": 50697,
|
1336 |
+
"<|6.66|>": 50698,
|
1337 |
+
"<|6.68|>": 50699,
|
1338 |
+
"<|6.70|>": 50700,
|
1339 |
+
"<|6.72|>": 50701,
|
1340 |
+
"<|6.74|>": 50702,
|
1341 |
+
"<|6.76|>": 50703,
|
1342 |
+
"<|6.78|>": 50704,
|
1343 |
+
"<|6.80|>": 50705,
|
1344 |
+
"<|6.82|>": 50706,
|
1345 |
+
"<|6.84|>": 50707,
|
1346 |
+
"<|6.86|>": 50708,
|
1347 |
+
"<|6.88|>": 50709,
|
1348 |
+
"<|6.90|>": 50710,
|
1349 |
+
"<|6.92|>": 50711,
|
1350 |
+
"<|6.94|>": 50712,
|
1351 |
+
"<|6.96|>": 50713,
|
1352 |
+
"<|6.98|>": 50714,
|
1353 |
+
"<|7.00|>": 50715,
|
1354 |
+
"<|7.02|>": 50716,
|
1355 |
+
"<|7.04|>": 50717,
|
1356 |
+
"<|7.06|>": 50718,
|
1357 |
+
"<|7.08|>": 50719,
|
1358 |
+
"<|7.10|>": 50720,
|
1359 |
+
"<|7.12|>": 50721,
|
1360 |
+
"<|7.14|>": 50722,
|
1361 |
+
"<|7.16|>": 50723,
|
1362 |
+
"<|7.18|>": 50724,
|
1363 |
+
"<|7.20|>": 50725,
|
1364 |
+
"<|7.22|>": 50726,
|
1365 |
+
"<|7.24|>": 50727,
|
1366 |
+
"<|7.26|>": 50728,
|
1367 |
+
"<|7.28|>": 50729,
|
1368 |
+
"<|7.30|>": 50730,
|
1369 |
+
"<|7.32|>": 50731,
|
1370 |
+
"<|7.34|>": 50732,
|
1371 |
+
"<|7.36|>": 50733,
|
1372 |
+
"<|7.38|>": 50734,
|
1373 |
+
"<|7.40|>": 50735,
|
1374 |
+
"<|7.42|>": 50736,
|
1375 |
+
"<|7.44|>": 50737,
|
1376 |
+
"<|7.46|>": 50738,
|
1377 |
+
"<|7.48|>": 50739,
|
1378 |
+
"<|7.50|>": 50740,
|
1379 |
+
"<|7.52|>": 50741,
|
1380 |
+
"<|7.54|>": 50742,
|
1381 |
+
"<|7.56|>": 50743,
|
1382 |
+
"<|7.58|>": 50744,
|
1383 |
+
"<|7.60|>": 50745,
|
1384 |
+
"<|7.62|>": 50746,
|
1385 |
+
"<|7.64|>": 50747,
|
1386 |
+
"<|7.66|>": 50748,
|
1387 |
+
"<|7.68|>": 50749,
|
1388 |
+
"<|7.70|>": 50750,
|
1389 |
+
"<|7.72|>": 50751,
|
1390 |
+
"<|7.74|>": 50752,
|
1391 |
+
"<|7.76|>": 50753,
|
1392 |
+
"<|7.78|>": 50754,
|
1393 |
+
"<|7.80|>": 50755,
|
1394 |
+
"<|7.82|>": 50756,
|
1395 |
+
"<|7.84|>": 50757,
|
1396 |
+
"<|7.86|>": 50758,
|
1397 |
+
"<|7.88|>": 50759,
|
1398 |
+
"<|7.90|>": 50760,
|
1399 |
+
"<|7.92|>": 50761,
|
1400 |
+
"<|7.94|>": 50762,
|
1401 |
+
"<|7.96|>": 50763,
|
1402 |
+
"<|7.98|>": 50764,
|
1403 |
+
"<|8.00|>": 50765,
|
1404 |
+
"<|8.02|>": 50766,
|
1405 |
+
"<|8.04|>": 50767,
|
1406 |
+
"<|8.06|>": 50768,
|
1407 |
+
"<|8.08|>": 50769,
|
1408 |
+
"<|8.10|>": 50770,
|
1409 |
+
"<|8.12|>": 50771,
|
1410 |
+
"<|8.14|>": 50772,
|
1411 |
+
"<|8.16|>": 50773,
|
1412 |
+
"<|8.18|>": 50774,
|
1413 |
+
"<|8.20|>": 50775,
|
1414 |
+
"<|8.22|>": 50776,
|
1415 |
+
"<|8.24|>": 50777,
|
1416 |
+
"<|8.26|>": 50778,
|
1417 |
+
"<|8.28|>": 50779,
|
1418 |
+
"<|8.30|>": 50780,
|
1419 |
+
"<|8.32|>": 50781,
|
1420 |
+
"<|8.34|>": 50782,
|
1421 |
+
"<|8.36|>": 50783,
|
1422 |
+
"<|8.38|>": 50784,
|
1423 |
+
"<|8.40|>": 50785,
|
1424 |
+
"<|8.42|>": 50786,
|
1425 |
+
"<|8.44|>": 50787,
|
1426 |
+
"<|8.46|>": 50788,
|
1427 |
+
"<|8.48|>": 50789,
|
1428 |
+
"<|8.50|>": 50790,
|
1429 |
+
"<|8.52|>": 50791,
|
1430 |
+
"<|8.54|>": 50792,
|
1431 |
+
"<|8.56|>": 50793,
|
1432 |
+
"<|8.58|>": 50794,
|
1433 |
+
"<|8.60|>": 50795,
|
1434 |
+
"<|8.62|>": 50796,
|
1435 |
+
"<|8.64|>": 50797,
|
1436 |
+
"<|8.66|>": 50798,
|
1437 |
+
"<|8.68|>": 50799,
|
1438 |
+
"<|8.70|>": 50800,
|
1439 |
+
"<|8.72|>": 50801,
|
1440 |
+
"<|8.74|>": 50802,
|
1441 |
+
"<|8.76|>": 50803,
|
1442 |
+
"<|8.78|>": 50804,
|
1443 |
+
"<|8.80|>": 50805,
|
1444 |
+
"<|8.82|>": 50806,
|
1445 |
+
"<|8.84|>": 50807,
|
1446 |
+
"<|8.86|>": 50808,
|
1447 |
+
"<|8.88|>": 50809,
|
1448 |
+
"<|8.90|>": 50810,
|
1449 |
+
"<|8.92|>": 50811,
|
1450 |
+
"<|8.94|>": 50812,
|
1451 |
+
"<|8.96|>": 50813,
|
1452 |
+
"<|8.98|>": 50814,
|
1453 |
+
"<|9.00|>": 50815,
|
1454 |
+
"<|9.02|>": 50816,
|
1455 |
+
"<|9.04|>": 50817,
|
1456 |
+
"<|9.06|>": 50818,
|
1457 |
+
"<|9.08|>": 50819,
|
1458 |
+
"<|9.10|>": 50820,
|
1459 |
+
"<|9.12|>": 50821,
|
1460 |
+
"<|9.14|>": 50822,
|
1461 |
+
"<|9.16|>": 50823,
|
1462 |
+
"<|9.18|>": 50824,
|
1463 |
+
"<|9.20|>": 50825,
|
1464 |
+
"<|9.22|>": 50826,
|
1465 |
+
"<|9.24|>": 50827,
|
1466 |
+
"<|9.26|>": 50828,
|
1467 |
+
"<|9.28|>": 50829,
|
1468 |
+
"<|9.30|>": 50830,
|
1469 |
+
"<|9.32|>": 50831,
|
1470 |
+
"<|9.34|>": 50832,
|
1471 |
+
"<|9.36|>": 50833,
|
1472 |
+
"<|9.38|>": 50834,
|
1473 |
+
"<|9.40|>": 50835,
|
1474 |
+
"<|9.42|>": 50836,
|
1475 |
+
"<|9.44|>": 50837,
|
1476 |
+
"<|9.46|>": 50838,
|
1477 |
+
"<|9.48|>": 50839,
|
1478 |
+
"<|9.50|>": 50840,
|
1479 |
+
"<|9.52|>": 50841,
|
1480 |
+
"<|9.54|>": 50842,
|
1481 |
+
"<|9.56|>": 50843,
|
1482 |
+
"<|9.58|>": 50844,
|
1483 |
+
"<|9.60|>": 50845,
|
1484 |
+
"<|9.62|>": 50846,
|
1485 |
+
"<|9.64|>": 50847,
|
1486 |
+
"<|9.66|>": 50848,
|
1487 |
+
"<|9.68|>": 50849,
|
1488 |
+
"<|9.70|>": 50850,
|
1489 |
+
"<|9.72|>": 50851,
|
1490 |
+
"<|9.74|>": 50852,
|
1491 |
+
"<|9.76|>": 50853,
|
1492 |
+
"<|9.78|>": 50854,
|
1493 |
+
"<|9.80|>": 50855,
|
1494 |
+
"<|9.82|>": 50856,
|
1495 |
+
"<|9.84|>": 50857,
|
1496 |
+
"<|9.86|>": 50858,
|
1497 |
+
"<|9.88|>": 50859,
|
1498 |
+
"<|9.90|>": 50860,
|
1499 |
+
"<|9.92|>": 50861,
|
1500 |
+
"<|9.94|>": 50862,
|
1501 |
+
"<|9.96|>": 50863,
|
1502 |
+
"<|9.98|>": 50864,
|
1503 |
+
"<|af|>": 50327,
|
1504 |
+
"<|am|>": 50334,
|
1505 |
+
"<|ar|>": 50272,
|
1506 |
+
"<|as|>": 50350,
|
1507 |
+
"<|az|>": 50304,
|
1508 |
+
"<|ba|>": 50355,
|
1509 |
+
"<|be|>": 50330,
|
1510 |
+
"<|bg|>": 50292,
|
1511 |
+
"<|bn|>": 50302,
|
1512 |
+
"<|bo|>": 50347,
|
1513 |
+
"<|br|>": 50309,
|
1514 |
+
"<|bs|>": 50315,
|
1515 |
+
"<|ca|>": 50270,
|
1516 |
+
"<|cs|>": 50283,
|
1517 |
+
"<|cy|>": 50297,
|
1518 |
+
"<|da|>": 50285,
|
1519 |
+
"<|de|>": 50261,
|
1520 |
+
"<|el|>": 50281,
|
1521 |
+
"<|endoftext|>": 50257,
|
1522 |
+
"<|en|>": 50259,
|
1523 |
+
"<|es|>": 50262,
|
1524 |
+
"<|et|>": 50307,
|
1525 |
+
"<|eu|>": 50310,
|
1526 |
+
"<|fa|>": 50300,
|
1527 |
+
"<|fi|>": 50277,
|
1528 |
+
"<|fo|>": 50338,
|
1529 |
+
"<|fr|>": 50265,
|
1530 |
+
"<|gl|>": 50319,
|
1531 |
+
"<|gu|>": 50333,
|
1532 |
+
"<|haw|>": 50352,
|
1533 |
+
"<|ha|>": 50354,
|
1534 |
+
"<|he|>": 50279,
|
1535 |
+
"<|hi|>": 50276,
|
1536 |
+
"<|hr|>": 50291,
|
1537 |
+
"<|ht|>": 50339,
|
1538 |
+
"<|hu|>": 50286,
|
1539 |
+
"<|hy|>": 50312,
|
1540 |
+
"<|id|>": 50275,
|
1541 |
+
"<|is|>": 50311,
|
1542 |
+
"<|it|>": 50274,
|
1543 |
+
"<|ja|>": 50266,
|
1544 |
+
"<|jw|>": 50356,
|
1545 |
+
"<|ka|>": 50329,
|
1546 |
+
"<|kk|>": 50316,
|
1547 |
+
"<|km|>": 50323,
|
1548 |
+
"<|kn|>": 50306,
|
1549 |
+
"<|ko|>": 50264,
|
1550 |
+
"<|la|>": 50294,
|
1551 |
+
"<|lb|>": 50345,
|
1552 |
+
"<|ln|>": 50353,
|
1553 |
+
"<|lo|>": 50336,
|
1554 |
+
"<|lt|>": 50293,
|
1555 |
+
"<|lv|>": 50301,
|
1556 |
+
"<|mg|>": 50349,
|
1557 |
+
"<|mi|>": 50295,
|
1558 |
+
"<|mk|>": 50308,
|
1559 |
+
"<|ml|>": 50296,
|
1560 |
+
"<|mn|>": 50314,
|
1561 |
+
"<|mr|>": 50320,
|
1562 |
+
"<|ms|>": 50282,
|
1563 |
+
"<|mt|>": 50343,
|
1564 |
+
"<|my|>": 50346,
|
1565 |
+
"<|ne|>": 50313,
|
1566 |
+
"<|nl|>": 50271,
|
1567 |
+
"<|nn|>": 50342,
|
1568 |
+
"<|nospeech|>": 50363,
|
1569 |
+
"<|notimestamps|>": 50364,
|
1570 |
+
"<|no|>": 50288,
|
1571 |
+
"<|oc|>": 50328,
|
1572 |
+
"<|pa|>": 50321,
|
1573 |
+
"<|pl|>": 50269,
|
1574 |
+
"<|ps|>": 50340,
|
1575 |
+
"<|pt|>": 50267,
|
1576 |
+
"<|ro|>": 50284,
|
1577 |
+
"<|ru|>": 50263,
|
1578 |
+
"<|sa|>": 50344,
|
1579 |
+
"<|sd|>": 50332,
|
1580 |
+
"<|si|>": 50322,
|
1581 |
+
"<|sk|>": 50298,
|
1582 |
+
"<|sl|>": 50305,
|
1583 |
+
"<|sn|>": 50324,
|
1584 |
+
"<|so|>": 50326,
|
1585 |
+
"<|sq|>": 50317,
|
1586 |
+
"<|sr|>": 50303,
|
1587 |
+
"<|startoflm|>": 50361,
|
1588 |
+
"<|startofprev|>": 50362,
|
1589 |
+
"<|startoftranscript|>": 50258,
|
1590 |
+
"<|su|>": 50357,
|
1591 |
+
"<|sv|>": 50273,
|
1592 |
+
"<|sw|>": 50318,
|
1593 |
+
"<|ta|>": 50287,
|
1594 |
+
"<|te|>": 50299,
|
1595 |
+
"<|tg|>": 50331,
|
1596 |
+
"<|th|>": 50289,
|
1597 |
+
"<|tk|>": 50341,
|
1598 |
+
"<|tl|>": 50348,
|
1599 |
+
"<|transcribe|>": 50360,
|
1600 |
+
"<|translate|>": 50359,
|
1601 |
+
"<|tr|>": 50268,
|
1602 |
+
"<|tt|>": 50351,
|
1603 |
+
"<|uk|>": 50280,
|
1604 |
+
"<|ur|>": 50290,
|
1605 |
+
"<|uz|>": 50337,
|
1606 |
+
"<|vi|>": 50278,
|
1607 |
+
"<|yi|>": 50335,
|
1608 |
+
"<|yo|>": 50325,
|
1609 |
+
"<|yue|>": 50358,
|
1610 |
+
"<|zh|>": 50260
|
1611 |
+
}
|
distil-large-v3-init/config.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "GalaktischeGurke/swhisper_large_8552",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"apply_spec_augment": false,
|
6 |
+
"architectures": [
|
7 |
+
"WhisperForConditionalGeneration"
|
8 |
+
],
|
9 |
+
"attention_dropout": 0.0,
|
10 |
+
"begin_suppress_tokens": [
|
11 |
+
220,
|
12 |
+
50257
|
13 |
+
],
|
14 |
+
"bos_token_id": 50257,
|
15 |
+
"classifier_proj_size": 256,
|
16 |
+
"d_model": 1280,
|
17 |
+
"decoder_attention_heads": 20,
|
18 |
+
"decoder_ffn_dim": 5120,
|
19 |
+
"decoder_layerdrop": 0.0,
|
20 |
+
"decoder_layers": 2,
|
21 |
+
"decoder_start_token_id": 50258,
|
22 |
+
"dropout": 0.0,
|
23 |
+
"encoder_attention_heads": 20,
|
24 |
+
"encoder_ffn_dim": 5120,
|
25 |
+
"encoder_layerdrop": 0.0,
|
26 |
+
"encoder_layers": 32,
|
27 |
+
"eos_token_id": 50257,
|
28 |
+
"init_std": 0.02,
|
29 |
+
"is_encoder_decoder": true,
|
30 |
+
"mask_feature_length": 10,
|
31 |
+
"mask_feature_min_masks": 0,
|
32 |
+
"mask_feature_prob": 0.0,
|
33 |
+
"mask_time_length": 10,
|
34 |
+
"mask_time_min_masks": 2,
|
35 |
+
"mask_time_prob": 0.05,
|
36 |
+
"max_length": 448,
|
37 |
+
"max_source_positions": 1500,
|
38 |
+
"max_target_positions": 448,
|
39 |
+
"median_filter_width": 7,
|
40 |
+
"model_type": "whisper",
|
41 |
+
"num_hidden_layers": 32,
|
42 |
+
"num_mel_bins": 128,
|
43 |
+
"pad_token_id": 50256,
|
44 |
+
"scale_embedding": false,
|
45 |
+
"torch_dtype": "float32",
|
46 |
+
"transformers_version": "4.41.2",
|
47 |
+
"use_cache": true,
|
48 |
+
"use_weighted_layer_sum": false,
|
49 |
+
"vocab_size": 51866
|
50 |
+
}
|
distil-large-v3-init/generation_config.json
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alignment_heads": [
|
3 |
+
[
|
4 |
+
7,
|
5 |
+
0
|
6 |
+
],
|
7 |
+
[
|
8 |
+
10,
|
9 |
+
17
|
10 |
+
],
|
11 |
+
[
|
12 |
+
12,
|
13 |
+
18
|
14 |
+
],
|
15 |
+
[
|
16 |
+
13,
|
17 |
+
12
|
18 |
+
],
|
19 |
+
[
|
20 |
+
16,
|
21 |
+
1
|
22 |
+
],
|
23 |
+
[
|
24 |
+
17,
|
25 |
+
14
|
26 |
+
],
|
27 |
+
[
|
28 |
+
19,
|
29 |
+
11
|
30 |
+
],
|
31 |
+
[
|
32 |
+
21,
|
33 |
+
4
|
34 |
+
],
|
35 |
+
[
|
36 |
+
24,
|
37 |
+
1
|
38 |
+
],
|
39 |
+
[
|
40 |
+
25,
|
41 |
+
6
|
42 |
+
]
|
43 |
+
],
|
44 |
+
"begin_suppress_tokens": [
|
45 |
+
220,
|
46 |
+
50257
|
47 |
+
],
|
48 |
+
"bos_token_id": 50257,
|
49 |
+
"decoder_start_token_id": 50258,
|
50 |
+
"eos_token_id": 50257,
|
51 |
+
"is_multilingual": true,
|
52 |
+
"lang_to_id": {
|
53 |
+
"<|af|>": 50327,
|
54 |
+
"<|am|>": 50334,
|
55 |
+
"<|ar|>": 50272,
|
56 |
+
"<|as|>": 50350,
|
57 |
+
"<|az|>": 50304,
|
58 |
+
"<|ba|>": 50355,
|
59 |
+
"<|be|>": 50330,
|
60 |
+
"<|bg|>": 50292,
|
61 |
+
"<|bn|>": 50302,
|
62 |
+
"<|bo|>": 50347,
|
63 |
+
"<|br|>": 50309,
|
64 |
+
"<|bs|>": 50315,
|
65 |
+
"<|ca|>": 50270,
|
66 |
+
"<|cs|>": 50283,
|
67 |
+
"<|cy|>": 50297,
|
68 |
+
"<|da|>": 50285,
|
69 |
+
"<|de|>": 50261,
|
70 |
+
"<|el|>": 50281,
|
71 |
+
"<|en|>": 50259,
|
72 |
+
"<|es|>": 50262,
|
73 |
+
"<|et|>": 50307,
|
74 |
+
"<|eu|>": 50310,
|
75 |
+
"<|fa|>": 50300,
|
76 |
+
"<|fi|>": 50277,
|
77 |
+
"<|fo|>": 50338,
|
78 |
+
"<|fr|>": 50265,
|
79 |
+
"<|gl|>": 50319,
|
80 |
+
"<|gu|>": 50333,
|
81 |
+
"<|haw|>": 50352,
|
82 |
+
"<|ha|>": 50354,
|
83 |
+
"<|he|>": 50279,
|
84 |
+
"<|hi|>": 50276,
|
85 |
+
"<|hr|>": 50291,
|
86 |
+
"<|ht|>": 50339,
|
87 |
+
"<|hu|>": 50286,
|
88 |
+
"<|hy|>": 50312,
|
89 |
+
"<|id|>": 50275,
|
90 |
+
"<|is|>": 50311,
|
91 |
+
"<|it|>": 50274,
|
92 |
+
"<|ja|>": 50266,
|
93 |
+
"<|jw|>": 50356,
|
94 |
+
"<|ka|>": 50329,
|
95 |
+
"<|kk|>": 50316,
|
96 |
+
"<|km|>": 50323,
|
97 |
+
"<|kn|>": 50306,
|
98 |
+
"<|ko|>": 50264,
|
99 |
+
"<|la|>": 50294,
|
100 |
+
"<|lb|>": 50345,
|
101 |
+
"<|ln|>": 50353,
|
102 |
+
"<|lo|>": 50336,
|
103 |
+
"<|lt|>": 50293,
|
104 |
+
"<|lv|>": 50301,
|
105 |
+
"<|mg|>": 50349,
|
106 |
+
"<|mi|>": 50295,
|
107 |
+
"<|mk|>": 50308,
|
108 |
+
"<|ml|>": 50296,
|
109 |
+
"<|mn|>": 50314,
|
110 |
+
"<|mr|>": 50320,
|
111 |
+
"<|ms|>": 50282,
|
112 |
+
"<|mt|>": 50343,
|
113 |
+
"<|my|>": 50346,
|
114 |
+
"<|ne|>": 50313,
|
115 |
+
"<|nl|>": 50271,
|
116 |
+
"<|nn|>": 50342,
|
117 |
+
"<|no|>": 50288,
|
118 |
+
"<|oc|>": 50328,
|
119 |
+
"<|pa|>": 50321,
|
120 |
+
"<|pl|>": 50269,
|
121 |
+
"<|ps|>": 50340,
|
122 |
+
"<|pt|>": 50267,
|
123 |
+
"<|ro|>": 50284,
|
124 |
+
"<|ru|>": 50263,
|
125 |
+
"<|sa|>": 50344,
|
126 |
+
"<|sd|>": 50332,
|
127 |
+
"<|si|>": 50322,
|
128 |
+
"<|sk|>": 50298,
|
129 |
+
"<|sl|>": 50305,
|
130 |
+
"<|sn|>": 50324,
|
131 |
+
"<|so|>": 50326,
|
132 |
+
"<|sq|>": 50317,
|
133 |
+
"<|sr|>": 50303,
|
134 |
+
"<|su|>": 50357,
|
135 |
+
"<|sv|>": 50273,
|
136 |
+
"<|sw|>": 50318,
|
137 |
+
"<|ta|>": 50287,
|
138 |
+
"<|te|>": 50299,
|
139 |
+
"<|tg|>": 50331,
|
140 |
+
"<|th|>": 50289,
|
141 |
+
"<|tk|>": 50341,
|
142 |
+
"<|tl|>": 50348,
|
143 |
+
"<|tr|>": 50268,
|
144 |
+
"<|tt|>": 50351,
|
145 |
+
"<|uk|>": 50280,
|
146 |
+
"<|ur|>": 50290,
|
147 |
+
"<|uz|>": 50337,
|
148 |
+
"<|vi|>": 50278,
|
149 |
+
"<|yi|>": 50335,
|
150 |
+
"<|yo|>": 50325,
|
151 |
+
"<|yue|>": 50358,
|
152 |
+
"<|zh|>": 50260
|
153 |
+
},
|
154 |
+
"max_initial_timestamp_index": 50,
|
155 |
+
"max_length": 448,
|
156 |
+
"no_timestamps_token_id": 50364,
|
157 |
+
"pad_token_id": 50257,
|
158 |
+
"prev_sot_token_id": 50362,
|
159 |
+
"return_timestamps": false,
|
160 |
+
"suppress_tokens": [
|
161 |
+
1,
|
162 |
+
2,
|
163 |
+
7,
|
164 |
+
8,
|
165 |
+
9,
|
166 |
+
10,
|
167 |
+
14,
|
168 |
+
25,
|
169 |
+
26,
|
170 |
+
27,
|
171 |
+
28,
|
172 |
+
29,
|
173 |
+
31,
|
174 |
+
58,
|
175 |
+
59,
|
176 |
+
60,
|
177 |
+
61,
|
178 |
+
62,
|
179 |
+
63,
|
180 |
+
90,
|
181 |
+
91,
|
182 |
+
92,
|
183 |
+
93,
|
184 |
+
359,
|
185 |
+
503,
|
186 |
+
522,
|
187 |
+
542,
|
188 |
+
873,
|
189 |
+
893,
|
190 |
+
902,
|
191 |
+
918,
|
192 |
+
922,
|
193 |
+
931,
|
194 |
+
1350,
|
195 |
+
1853,
|
196 |
+
1982,
|
197 |
+
2460,
|
198 |
+
2627,
|
199 |
+
3246,
|
200 |
+
3253,
|
201 |
+
3268,
|
202 |
+
3536,
|
203 |
+
3846,
|
204 |
+
3961,
|
205 |
+
4183,
|
206 |
+
4667,
|
207 |
+
6585,
|
208 |
+
6647,
|
209 |
+
7273,
|
210 |
+
9061,
|
211 |
+
9383,
|
212 |
+
10428,
|
213 |
+
10929,
|
214 |
+
11938,
|
215 |
+
12033,
|
216 |
+
12331,
|
217 |
+
12562,
|
218 |
+
13793,
|
219 |
+
14157,
|
220 |
+
14635,
|
221 |
+
15265,
|
222 |
+
15618,
|
223 |
+
16553,
|
224 |
+
16604,
|
225 |
+
18362,
|
226 |
+
18956,
|
227 |
+
20075,
|
228 |
+
21675,
|
229 |
+
22520,
|
230 |
+
26130,
|
231 |
+
26161,
|
232 |
+
26435,
|
233 |
+
28279,
|
234 |
+
29464,
|
235 |
+
31650,
|
236 |
+
32302,
|
237 |
+
32470,
|
238 |
+
36865,
|
239 |
+
42863,
|
240 |
+
47425,
|
241 |
+
49870,
|
242 |
+
50254,
|
243 |
+
50258,
|
244 |
+
50359,
|
245 |
+
50360,
|
246 |
+
50361,
|
247 |
+
50362,
|
248 |
+
50363
|
249 |
+
],
|
250 |
+
"task_to_id": {
|
251 |
+
"transcribe": 50360,
|
252 |
+
"translate": 50359
|
253 |
+
},
|
254 |
+
"transformers_version": "4.41.2"
|
255 |
+
}
|
distil-large-v3-init/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
distil-large-v3-init/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:34de595c8d40aa1c243b11d5b6d2b4d282e3c5aaa46180b45bf90b6c33d6e924
|
3 |
+
size 3025686376
|
distil-large-v3-init/normalizer.json
ADDED
@@ -0,0 +1,1742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"accessorise": "accessorize",
|
3 |
+
"accessorised": "accessorized",
|
4 |
+
"accessorises": "accessorizes",
|
5 |
+
"accessorising": "accessorizing",
|
6 |
+
"acclimatisation": "acclimatization",
|
7 |
+
"acclimatise": "acclimatize",
|
8 |
+
"acclimatised": "acclimatized",
|
9 |
+
"acclimatises": "acclimatizes",
|
10 |
+
"acclimatising": "acclimatizing",
|
11 |
+
"accoutrements": "accouterments",
|
12 |
+
"aeon": "eon",
|
13 |
+
"aeons": "eons",
|
14 |
+
"aerogramme": "aerogram",
|
15 |
+
"aerogrammes": "aerograms",
|
16 |
+
"aeroplane": "airplane",
|
17 |
+
"aeroplanes": "airplanes",
|
18 |
+
"aesthete": "esthete",
|
19 |
+
"aesthetes": "esthetes",
|
20 |
+
"aesthetic": "esthetic",
|
21 |
+
"aesthetically": "esthetically",
|
22 |
+
"aesthetics": "esthetics",
|
23 |
+
"aetiology": "etiology",
|
24 |
+
"ageing": "aging",
|
25 |
+
"aggrandisement": "aggrandizement",
|
26 |
+
"agonise": "agonize",
|
27 |
+
"agonised": "agonized",
|
28 |
+
"agonises": "agonizes",
|
29 |
+
"agonising": "agonizing",
|
30 |
+
"agonisingly": "agonizingly",
|
31 |
+
"almanack": "almanac",
|
32 |
+
"almanacks": "almanacs",
|
33 |
+
"aluminium": "aluminum",
|
34 |
+
"amortisable": "amortizable",
|
35 |
+
"amortisation": "amortization",
|
36 |
+
"amortisations": "amortizations",
|
37 |
+
"amortise": "amortize",
|
38 |
+
"amortised": "amortized",
|
39 |
+
"amortises": "amortizes",
|
40 |
+
"amortising": "amortizing",
|
41 |
+
"amphitheatre": "amphitheater",
|
42 |
+
"amphitheatres": "amphitheaters",
|
43 |
+
"anaemia": "anemia",
|
44 |
+
"anaemic": "anemic",
|
45 |
+
"anaesthesia": "anesthesia",
|
46 |
+
"anaesthetic": "anesthetic",
|
47 |
+
"anaesthetics": "anesthetics",
|
48 |
+
"anaesthetise": "anesthetize",
|
49 |
+
"anaesthetised": "anesthetized",
|
50 |
+
"anaesthetises": "anesthetizes",
|
51 |
+
"anaesthetising": "anesthetizing",
|
52 |
+
"anaesthetist": "anesthetist",
|
53 |
+
"anaesthetists": "anesthetists",
|
54 |
+
"anaesthetize": "anesthetize",
|
55 |
+
"anaesthetized": "anesthetized",
|
56 |
+
"anaesthetizes": "anesthetizes",
|
57 |
+
"anaesthetizing": "anesthetizing",
|
58 |
+
"analogue": "analog",
|
59 |
+
"analogues": "analogs",
|
60 |
+
"analyse": "analyze",
|
61 |
+
"analysed": "analyzed",
|
62 |
+
"analyses": "analyzes",
|
63 |
+
"analysing": "analyzing",
|
64 |
+
"anglicise": "anglicize",
|
65 |
+
"anglicised": "anglicized",
|
66 |
+
"anglicises": "anglicizes",
|
67 |
+
"anglicising": "anglicizing",
|
68 |
+
"annualised": "annualized",
|
69 |
+
"antagonise": "antagonize",
|
70 |
+
"antagonised": "antagonized",
|
71 |
+
"antagonises": "antagonizes",
|
72 |
+
"antagonising": "antagonizing",
|
73 |
+
"apologise": "apologize",
|
74 |
+
"apologised": "apologized",
|
75 |
+
"apologises": "apologizes",
|
76 |
+
"apologising": "apologizing",
|
77 |
+
"appal": "appall",
|
78 |
+
"appals": "appalls",
|
79 |
+
"appetiser": "appetizer",
|
80 |
+
"appetisers": "appetizers",
|
81 |
+
"appetising": "appetizing",
|
82 |
+
"appetisingly": "appetizingly",
|
83 |
+
"arbour": "arbor",
|
84 |
+
"arbours": "arbors",
|
85 |
+
"archaeologically": "archeologically",
|
86 |
+
"archaeologist": "archeologist",
|
87 |
+
"archaeologists": "archeologists",
|
88 |
+
"archaeology": "archeology</span>",
|
89 |
+
"archeological": "archaeological",
|
90 |
+
"ardour": "ardor",
|
91 |
+
"armour": "armor",
|
92 |
+
"armoured": "armored",
|
93 |
+
"armourer": "armorer",
|
94 |
+
"armourers": "armorers",
|
95 |
+
"armouries": "armories",
|
96 |
+
"armoury": "armory",
|
97 |
+
"artefact": "artifact",
|
98 |
+
"artefacts": "artifacts",
|
99 |
+
"authorise": "authorize",
|
100 |
+
"authorised": "authorized",
|
101 |
+
"authorises": "authorizes",
|
102 |
+
"authorising": "authorizing",
|
103 |
+
"axe": "ax",
|
104 |
+
"backpedalled": "backpedaled",
|
105 |
+
"backpedalling": "backpedaling",
|
106 |
+
"bannister": "banister",
|
107 |
+
"bannisters": "banisters",
|
108 |
+
"baptise": "baptize",
|
109 |
+
"baptised": "baptized",
|
110 |
+
"baptises": "baptizes",
|
111 |
+
"baptising": "baptizing",
|
112 |
+
"bastardise": "bastardize",
|
113 |
+
"bastardised": "bastardized",
|
114 |
+
"bastardises": "bastardizes",
|
115 |
+
"bastardising": "bastardizing",
|
116 |
+
"battleax": "battleaxe",
|
117 |
+
"baulk": "balk",
|
118 |
+
"baulked": "balked",
|
119 |
+
"baulking": "balking",
|
120 |
+
"baulks": "balks",
|
121 |
+
"bedevilled": "bedeviled",
|
122 |
+
"bedevilling": "bedeviling",
|
123 |
+
"behaviour": "behavior",
|
124 |
+
"behavioural": "behavioral",
|
125 |
+
"behaviourism": "behaviorism",
|
126 |
+
"behaviourist": "behaviorist",
|
127 |
+
"behaviourists": "behaviorists",
|
128 |
+
"behaviours": "behaviors",
|
129 |
+
"behove": "behoove",
|
130 |
+
"behoved": "behooved",
|
131 |
+
"behoves": "behooves",
|
132 |
+
"bejewelled": "bejeweled",
|
133 |
+
"belabour": "belabor",
|
134 |
+
"belaboured": "belabored",
|
135 |
+
"belabouring": "belaboring",
|
136 |
+
"belabours": "belabors",
|
137 |
+
"bevelled": "beveled",
|
138 |
+
"bevvies": "bevies",
|
139 |
+
"bevvy": "bevy",
|
140 |
+
"biassed": "biased",
|
141 |
+
"biassing": "biasing",
|
142 |
+
"bingeing": "binging",
|
143 |
+
"bougainvillaea": "bougainvillea",
|
144 |
+
"bougainvillaeas": "bougainvilleas",
|
145 |
+
"bowdlerise": "bowdlerize",
|
146 |
+
"bowdlerised": "bowdlerized",
|
147 |
+
"bowdlerises": "bowdlerizes",
|
148 |
+
"bowdlerising": "bowdlerizing",
|
149 |
+
"breathalyse": "breathalyze",
|
150 |
+
"breathalysed": "breathalyzed",
|
151 |
+
"breathalyser": "breathalyzer",
|
152 |
+
"breathalysers": "breathalyzers",
|
153 |
+
"breathalyses": "breathalyzes",
|
154 |
+
"breathalysing": "breathalyzing",
|
155 |
+
"brutalise": "brutalize",
|
156 |
+
"brutalised": "brutalized",
|
157 |
+
"brutalises": "brutalizes",
|
158 |
+
"brutalising": "brutalizing",
|
159 |
+
"busses": "buses",
|
160 |
+
"bussing": "busing",
|
161 |
+
"caesarean": "cesarean",
|
162 |
+
"caesareans": "cesareans",
|
163 |
+
"calibre": "caliber",
|
164 |
+
"calibres": "calibers",
|
165 |
+
"calliper": "caliper",
|
166 |
+
"callipers": "calipers",
|
167 |
+
"callisthenics": "calisthenics",
|
168 |
+
"canalise": "canalize",
|
169 |
+
"canalised": "canalized",
|
170 |
+
"canalises": "canalizes",
|
171 |
+
"canalising": "canalizing",
|
172 |
+
"cancelation": "cancellation",
|
173 |
+
"cancelations": "cancellations",
|
174 |
+
"cancelled": "canceled",
|
175 |
+
"cancelling": "canceling",
|
176 |
+
"candour": "candor",
|
177 |
+
"cannibalise": "cannibalize",
|
178 |
+
"cannibalised": "cannibalized",
|
179 |
+
"cannibalises": "cannibalizes",
|
180 |
+
"cannibalising": "cannibalizing",
|
181 |
+
"canonise": "canonize",
|
182 |
+
"canonised": "canonized",
|
183 |
+
"canonises": "canonizes",
|
184 |
+
"canonising": "canonizing",
|
185 |
+
"capitalise": "capitalize",
|
186 |
+
"capitalised": "capitalized",
|
187 |
+
"capitalises": "capitalizes",
|
188 |
+
"capitalising": "capitalizing",
|
189 |
+
"caramelise": "caramelize",
|
190 |
+
"caramelised": "caramelized",
|
191 |
+
"caramelises": "caramelizes",
|
192 |
+
"caramelising": "caramelizing",
|
193 |
+
"carbonise": "carbonize",
|
194 |
+
"carbonised": "carbonized",
|
195 |
+
"carbonises": "carbonizes",
|
196 |
+
"carbonising": "carbonizing",
|
197 |
+
"carolled": "caroled",
|
198 |
+
"carolling": "caroling",
|
199 |
+
"catalogue": "catalog",
|
200 |
+
"catalogued": "cataloged",
|
201 |
+
"catalogues": "catalogs",
|
202 |
+
"cataloguing": "cataloging",
|
203 |
+
"catalyse": "catalyze",
|
204 |
+
"catalysed": "catalyzed",
|
205 |
+
"catalyses": "catalyzes",
|
206 |
+
"catalysing": "catalyzing",
|
207 |
+
"categorise": "categorize",
|
208 |
+
"categorised": "categorized",
|
209 |
+
"categorises": "categorizes",
|
210 |
+
"categorising": "categorizing",
|
211 |
+
"cauterise": "cauterize",
|
212 |
+
"cauterised": "cauterized",
|
213 |
+
"cauterises": "cauterizes",
|
214 |
+
"cauterising": "cauterizing",
|
215 |
+
"cavilled": "caviled",
|
216 |
+
"cavilling": "caviling",
|
217 |
+
"centigramme": "centigram",
|
218 |
+
"centigrammes": "centigrams",
|
219 |
+
"centilitre": "centiliter",
|
220 |
+
"centilitres": "centiliters",
|
221 |
+
"centimetre": "centimeter",
|
222 |
+
"centimetres": "centimeters",
|
223 |
+
"centralise": "centralize",
|
224 |
+
"centralised": "centralized",
|
225 |
+
"centralises": "centralizes",
|
226 |
+
"centralising": "centralizing",
|
227 |
+
"centre": "center",
|
228 |
+
"centred": "centered",
|
229 |
+
"centrefold": "centerfold",
|
230 |
+
"centrefolds": "centerfolds",
|
231 |
+
"centrepiece": "centerpiece",
|
232 |
+
"centrepieces": "centerpieces",
|
233 |
+
"centres": "centers",
|
234 |
+
"channelled": "channeled",
|
235 |
+
"channelling": "channeling",
|
236 |
+
"characterise": "characterize",
|
237 |
+
"characterised": "characterized",
|
238 |
+
"characterises": "characterizes",
|
239 |
+
"characterising": "characterizing",
|
240 |
+
"cheque": "check",
|
241 |
+
"chequebook": "checkbook",
|
242 |
+
"chequebooks": "checkbooks",
|
243 |
+
"chequered": "checkered",
|
244 |
+
"cheques": "checks",
|
245 |
+
"chilli": "chili",
|
246 |
+
"chimaera": "chimera",
|
247 |
+
"chimaeras": "chimeras",
|
248 |
+
"chiselled": "chiseled",
|
249 |
+
"chiselling": "chiseling",
|
250 |
+
"circularise": "circularize",
|
251 |
+
"circularised": "circularized",
|
252 |
+
"circularises": "circularizes",
|
253 |
+
"circularising": "circularizing",
|
254 |
+
"civilise": "civilize",
|
255 |
+
"civilised": "civilized",
|
256 |
+
"civilises": "civilizes",
|
257 |
+
"civilising": "civilizing",
|
258 |
+
"clamour": "clamor",
|
259 |
+
"clamoured": "clamored",
|
260 |
+
"clamouring": "clamoring",
|
261 |
+
"clamours": "clamors",
|
262 |
+
"clangour": "clangor",
|
263 |
+
"clarinettist": "clarinetist",
|
264 |
+
"clarinettists": "clarinetists",
|
265 |
+
"collectivise": "collectivize",
|
266 |
+
"collectivised": "collectivized",
|
267 |
+
"collectivises": "collectivizes",
|
268 |
+
"collectivising": "collectivizing",
|
269 |
+
"colonisation": "colonization",
|
270 |
+
"colonise": "colonize",
|
271 |
+
"colonised": "colonized",
|
272 |
+
"coloniser": "colonizer",
|
273 |
+
"colonisers": "colonizers",
|
274 |
+
"colonises": "colonizes",
|
275 |
+
"colonising": "colonizing",
|
276 |
+
"colour": "color",
|
277 |
+
"colourant": "colorant",
|
278 |
+
"colourants": "colorants",
|
279 |
+
"coloured": "colored",
|
280 |
+
"coloureds": "coloreds",
|
281 |
+
"colourful": "colorful",
|
282 |
+
"colourfully": "colorfully",
|
283 |
+
"colouring": "coloring",
|
284 |
+
"colourize": "colorize",
|
285 |
+
"colourized": "colorized",
|
286 |
+
"colourizes": "colorizes",
|
287 |
+
"colourizing": "colorizing",
|
288 |
+
"colourless": "colorless",
|
289 |
+
"colours": "colors",
|
290 |
+
"commercialise": "commercialize",
|
291 |
+
"commercialised": "commercialized",
|
292 |
+
"commercialises": "commercializes",
|
293 |
+
"commercialising": "commercializing",
|
294 |
+
"compartmentalise": "compartmentalize",
|
295 |
+
"compartmentalised": "compartmentalized",
|
296 |
+
"compartmentalises": "compartmentalizes",
|
297 |
+
"compartmentalising": "compartmentalizing",
|
298 |
+
"computerise": "computerize",
|
299 |
+
"computerised": "computerized",
|
300 |
+
"computerises": "computerizes",
|
301 |
+
"computerising": "computerizing",
|
302 |
+
"conceptualise": "conceptualize",
|
303 |
+
"conceptualised": "conceptualized",
|
304 |
+
"conceptualises": "conceptualizes",
|
305 |
+
"conceptualising": "conceptualizing",
|
306 |
+
"connexion": "connection",
|
307 |
+
"connexions": "connections",
|
308 |
+
"contextualise": "contextualize",
|
309 |
+
"contextualised": "contextualized",
|
310 |
+
"contextualises": "contextualizes",
|
311 |
+
"contextualising": "contextualizing",
|
312 |
+
"cosier": "cozier",
|
313 |
+
"cosies": "cozies",
|
314 |
+
"cosiest": "coziest",
|
315 |
+
"cosily": "cozily",
|
316 |
+
"cosiness": "coziness",
|
317 |
+
"cosy": "cozy",
|
318 |
+
"councillor": "councilor",
|
319 |
+
"councillors": "councilors",
|
320 |
+
"counselled": "counseled",
|
321 |
+
"counselling": "counseling",
|
322 |
+
"counsellor": "counselor",
|
323 |
+
"counsellors": "counselors",
|
324 |
+
"crenelated": "crenellated",
|
325 |
+
"criminalise": "criminalize",
|
326 |
+
"criminalised": "criminalized",
|
327 |
+
"criminalises": "criminalizes",
|
328 |
+
"criminalising": "criminalizing",
|
329 |
+
"criticise": "criticize",
|
330 |
+
"criticised": "criticized",
|
331 |
+
"criticises": "criticizes",
|
332 |
+
"criticising": "criticizing",
|
333 |
+
"crueller": "crueler",
|
334 |
+
"cruellest": "cruelest",
|
335 |
+
"crystallisation": "crystallization",
|
336 |
+
"crystallise": "crystallize",
|
337 |
+
"crystallised": "crystallized",
|
338 |
+
"crystallises": "crystallizes",
|
339 |
+
"crystallising": "crystallizing",
|
340 |
+
"cudgelled": "cudgeled",
|
341 |
+
"cudgelling": "cudgeling",
|
342 |
+
"customise": "customize",
|
343 |
+
"customised": "customized",
|
344 |
+
"customises": "customizes",
|
345 |
+
"customising": "customizing",
|
346 |
+
"cypher": "cipher",
|
347 |
+
"cyphers": "ciphers",
|
348 |
+
"decentralisation": "decentralization",
|
349 |
+
"decentralise": "decentralize",
|
350 |
+
"decentralised": "decentralized",
|
351 |
+
"decentralises": "decentralizes",
|
352 |
+
"decentralising": "decentralizing",
|
353 |
+
"decriminalisation": "decriminalization",
|
354 |
+
"decriminalise": "decriminalize",
|
355 |
+
"decriminalised": "decriminalized",
|
356 |
+
"decriminalises": "decriminalizes",
|
357 |
+
"decriminalising": "decriminalizing",
|
358 |
+
"defence": "defense",
|
359 |
+
"defenceless": "defenseless",
|
360 |
+
"defences": "defenses",
|
361 |
+
"dehumanisation": "dehumanization",
|
362 |
+
"dehumanise": "dehumanize",
|
363 |
+
"dehumanised": "dehumanized",
|
364 |
+
"dehumanises": "dehumanizes",
|
365 |
+
"dehumanising": "dehumanizing",
|
366 |
+
"demeanour": "demeanor",
|
367 |
+
"demilitarisation": "demilitarization",
|
368 |
+
"demilitarise": "demilitarize",
|
369 |
+
"demilitarised": "demilitarized",
|
370 |
+
"demilitarises": "demilitarizes",
|
371 |
+
"demilitarising": "demilitarizing",
|
372 |
+
"demobilisation": "demobilization",
|
373 |
+
"demobilise": "demobilize",
|
374 |
+
"demobilised": "demobilized",
|
375 |
+
"demobilises": "demobilizes",
|
376 |
+
"demobilising": "demobilizing",
|
377 |
+
"democratisation": "democratization",
|
378 |
+
"democratise": "democratize",
|
379 |
+
"democratised": "democratized",
|
380 |
+
"democratises": "democratizes",
|
381 |
+
"democratising": "democratizing",
|
382 |
+
"demonise": "demonize",
|
383 |
+
"demonised": "demonized",
|
384 |
+
"demonises": "demonizes",
|
385 |
+
"demonising": "demonizing",
|
386 |
+
"demoralisation": "demoralization",
|
387 |
+
"demoralise": "demoralize",
|
388 |
+
"demoralised": "demoralized",
|
389 |
+
"demoralises": "demoralizes",
|
390 |
+
"demoralising": "demoralizing",
|
391 |
+
"denationalisation": "denationalization",
|
392 |
+
"denationalise": "denationalize",
|
393 |
+
"denationalised": "denationalized",
|
394 |
+
"denationalises": "denationalizes",
|
395 |
+
"denationalising": "denationalizing",
|
396 |
+
"deodorise": "deodorize",
|
397 |
+
"deodorised": "deodorized",
|
398 |
+
"deodorises": "deodorizes",
|
399 |
+
"deodorising": "deodorizing",
|
400 |
+
"depersonalise": "depersonalize",
|
401 |
+
"depersonalised": "depersonalized",
|
402 |
+
"depersonalises": "depersonalizes",
|
403 |
+
"depersonalising": "depersonalizing",
|
404 |
+
"deputise": "deputize",
|
405 |
+
"deputised": "deputized",
|
406 |
+
"deputises": "deputizes",
|
407 |
+
"deputising": "deputizing",
|
408 |
+
"desensitisation": "desensitization",
|
409 |
+
"desensitise": "desensitize",
|
410 |
+
"desensitised": "desensitized",
|
411 |
+
"desensitises": "desensitizes",
|
412 |
+
"desensitising": "desensitizing",
|
413 |
+
"destabilisation": "destabilization",
|
414 |
+
"destabilise": "destabilize",
|
415 |
+
"destabilised": "destabilized",
|
416 |
+
"destabilises": "destabilizes",
|
417 |
+
"destabilising": "destabilizing",
|
418 |
+
"dialled": "dialed",
|
419 |
+
"dialling": "dialing",
|
420 |
+
"dialogue": "dialog",
|
421 |
+
"dialogues": "dialogs",
|
422 |
+
"diarrhoea": "diarrhea",
|
423 |
+
"digitise": "digitize",
|
424 |
+
"digitised": "digitized",
|
425 |
+
"digitises": "digitizes",
|
426 |
+
"digitising": "digitizing",
|
427 |
+
"disc": "disk",
|
428 |
+
"discolour": "discolor",
|
429 |
+
"discoloured": "discolored",
|
430 |
+
"discolouring": "discoloring",
|
431 |
+
"discolours": "discolors",
|
432 |
+
"discs": "disks",
|
433 |
+
"disembowelled": "disemboweled",
|
434 |
+
"disembowelling": "disemboweling",
|
435 |
+
"disfavour": "disfavor",
|
436 |
+
"dishevelled": "disheveled",
|
437 |
+
"dishonour": "dishonor",
|
438 |
+
"dishonourable": "dishonorable",
|
439 |
+
"dishonourably": "dishonorably",
|
440 |
+
"dishonoured": "dishonored",
|
441 |
+
"dishonouring": "dishonoring",
|
442 |
+
"dishonours": "dishonors",
|
443 |
+
"disorganisation": "disorganization",
|
444 |
+
"disorganised": "disorganized",
|
445 |
+
"distil": "distill",
|
446 |
+
"distils": "distills",
|
447 |
+
"dramatisation": "dramatization",
|
448 |
+
"dramatisations": "dramatizations",
|
449 |
+
"dramatise": "dramatize",
|
450 |
+
"dramatised": "dramatized",
|
451 |
+
"dramatises": "dramatizes",
|
452 |
+
"dramatising": "dramatizing",
|
453 |
+
"draught": "draft",
|
454 |
+
"draughtboard": "draftboard",
|
455 |
+
"draughtboards": "draftboards",
|
456 |
+
"draughtier": "draftier",
|
457 |
+
"draughtiest": "draftiest",
|
458 |
+
"draughts": "drafts",
|
459 |
+
"draughtsman": "draftsman",
|
460 |
+
"draughtsmanship": "draftsmanship",
|
461 |
+
"draughtsmen": "draftsmen",
|
462 |
+
"draughtswoman": "draftswoman",
|
463 |
+
"draughtswomen": "draftswomen",
|
464 |
+
"draughty": "drafty",
|
465 |
+
"drivelled": "driveled",
|
466 |
+
"drivelling": "driveling",
|
467 |
+
"duelled": "dueled",
|
468 |
+
"duelling": "dueling",
|
469 |
+
"economise": "economize",
|
470 |
+
"economised": "economized",
|
471 |
+
"economises": "economizes",
|
472 |
+
"economising": "economizing",
|
473 |
+
"editorialise": "editorialize",
|
474 |
+
"editorialised": "editorialized",
|
475 |
+
"editorialises": "editorializes",
|
476 |
+
"editorialising": "editorializing",
|
477 |
+
"edoema": "edema",
|
478 |
+
"empathise": "empathize",
|
479 |
+
"empathised": "empathized",
|
480 |
+
"empathises": "empathizes",
|
481 |
+
"empathising": "empathizing",
|
482 |
+
"emphasise": "emphasize",
|
483 |
+
"emphasised": "emphasized",
|
484 |
+
"emphasises": "emphasizes",
|
485 |
+
"emphasising": "emphasizing",
|
486 |
+
"enamelled": "enameled",
|
487 |
+
"enamelling": "enameling",
|
488 |
+
"enamoured": "enamored",
|
489 |
+
"encyclopaedia": "encyclopedia",
|
490 |
+
"encyclopaedias": "encyclopedias",
|
491 |
+
"encyclopaedic": "encyclopedic",
|
492 |
+
"endeavour": "endeavor",
|
493 |
+
"endeavoured": "endeavored",
|
494 |
+
"endeavouring": "endeavoring",
|
495 |
+
"endeavours": "endeavors",
|
496 |
+
"energise": "energize",
|
497 |
+
"energised": "energized",
|
498 |
+
"energises": "energizes",
|
499 |
+
"energising": "energizing",
|
500 |
+
"enrol": "enroll",
|
501 |
+
"enrols": "enrolls",
|
502 |
+
"enthral": "enthrall",
|
503 |
+
"enthrals": "enthralls",
|
504 |
+
"epaulette": "epaulet",
|
505 |
+
"epaulettes": "epaulets",
|
506 |
+
"epicentre": "epicenter",
|
507 |
+
"epicentres": "epicenters",
|
508 |
+
"epilogue": "epilog",
|
509 |
+
"epilogues": "epilogs",
|
510 |
+
"epitomise": "epitomize",
|
511 |
+
"epitomised": "epitomized",
|
512 |
+
"epitomises": "epitomizes",
|
513 |
+
"epitomising": "epitomizing",
|
514 |
+
"equalisation": "equalization",
|
515 |
+
"equalise": "equalize",
|
516 |
+
"equalised": "equalized",
|
517 |
+
"equaliser": "equalizer",
|
518 |
+
"equalisers": "equalizers",
|
519 |
+
"equalises": "equalizes",
|
520 |
+
"equalising": "equalizing",
|
521 |
+
"eulogise": "eulogize",
|
522 |
+
"eulogised": "eulogized",
|
523 |
+
"eulogises": "eulogizes",
|
524 |
+
"eulogising": "eulogizing",
|
525 |
+
"evangelise": "evangelize",
|
526 |
+
"evangelised": "evangelized",
|
527 |
+
"evangelises": "evangelizes",
|
528 |
+
"evangelising": "evangelizing",
|
529 |
+
"exorcise": "exorcize",
|
530 |
+
"exorcised": "exorcized",
|
531 |
+
"exorcises": "exorcizes",
|
532 |
+
"exorcising": "exorcizing",
|
533 |
+
"extemporisation": "extemporization",
|
534 |
+
"extemporise": "extemporize",
|
535 |
+
"extemporised": "extemporized",
|
536 |
+
"extemporises": "extemporizes",
|
537 |
+
"extemporising": "extemporizing",
|
538 |
+
"externalisation": "externalization",
|
539 |
+
"externalisations": "externalizations",
|
540 |
+
"externalise": "externalize",
|
541 |
+
"externalised": "externalized",
|
542 |
+
"externalises": "externalizes",
|
543 |
+
"externalising": "externalizing",
|
544 |
+
"factorise": "factorize",
|
545 |
+
"factorised": "factorized",
|
546 |
+
"factorises": "factorizes",
|
547 |
+
"factorising": "factorizing",
|
548 |
+
"faecal": "fecal",
|
549 |
+
"faeces": "feces",
|
550 |
+
"familiarisation": "familiarization",
|
551 |
+
"familiarise": "familiarize",
|
552 |
+
"familiarised": "familiarized",
|
553 |
+
"familiarises": "familiarizes",
|
554 |
+
"familiarising": "familiarizing",
|
555 |
+
"fantasise": "fantasize",
|
556 |
+
"fantasised": "fantasized",
|
557 |
+
"fantasises": "fantasizes",
|
558 |
+
"fantasising": "fantasizing",
|
559 |
+
"favour": "favor",
|
560 |
+
"favourable": "favorable",
|
561 |
+
"favourably": "favorably",
|
562 |
+
"favoured": "favored",
|
563 |
+
"favouring": "favoring",
|
564 |
+
"favourite": "favorite",
|
565 |
+
"favourites": "favorites",
|
566 |
+
"favouritism": "favoritism",
|
567 |
+
"favours": "favors",
|
568 |
+
"feminise": "feminize",
|
569 |
+
"feminised": "feminized",
|
570 |
+
"feminises": "feminizes",
|
571 |
+
"feminising": "feminizing",
|
572 |
+
"fertilisation": "fertilization",
|
573 |
+
"fertilise": "fertilize",
|
574 |
+
"fertilised": "fertilized",
|
575 |
+
"fertiliser": "fertilizer",
|
576 |
+
"fertilisers": "fertilizers",
|
577 |
+
"fertilises": "fertilizes",
|
578 |
+
"fertilising": "fertilizing",
|
579 |
+
"fervour": "fervor",
|
580 |
+
"fibre": "fiber",
|
581 |
+
"fibreglass": "fiberglass",
|
582 |
+
"fibres": "fibers",
|
583 |
+
"fictionalisation": "fictionalization",
|
584 |
+
"fictionalisations": "fictionalizations",
|
585 |
+
"fictionalise": "fictionalize",
|
586 |
+
"fictionalised": "fictionalized",
|
587 |
+
"fictionalises": "fictionalizes",
|
588 |
+
"fictionalising": "fictionalizing",
|
589 |
+
"fillet": "filet",
|
590 |
+
"filleted": "fileted",
|
591 |
+
"filleting": "fileting",
|
592 |
+
"fillets": "filets",
|
593 |
+
"finalisation": "finalization",
|
594 |
+
"finalise": "finalize",
|
595 |
+
"finalised": "finalized",
|
596 |
+
"finalises": "finalizes",
|
597 |
+
"finalising": "finalizing",
|
598 |
+
"flautist": "flutist",
|
599 |
+
"flautists": "flutists",
|
600 |
+
"flavour": "flavor",
|
601 |
+
"flavoured": "flavored",
|
602 |
+
"flavouring": "flavoring",
|
603 |
+
"flavourings": "flavorings",
|
604 |
+
"flavourless": "flavorless",
|
605 |
+
"flavours": "flavors",
|
606 |
+
"flavoursome": "flavorsome",
|
607 |
+
"flyer / flier": "flier / flyer",
|
608 |
+
"foetal": "fetal",
|
609 |
+
"foetid": "fetid",
|
610 |
+
"foetus": "fetus",
|
611 |
+
"foetuses": "fetuses",
|
612 |
+
"formalisation": "formalization",
|
613 |
+
"formalise": "formalize",
|
614 |
+
"formalised": "formalized",
|
615 |
+
"formalises": "formalizes",
|
616 |
+
"formalising": "formalizing",
|
617 |
+
"fossilisation": "fossilization",
|
618 |
+
"fossilise": "fossilize",
|
619 |
+
"fossilised": "fossilized",
|
620 |
+
"fossilises": "fossilizes",
|
621 |
+
"fossilising": "fossilizing",
|
622 |
+
"fraternisation": "fraternization",
|
623 |
+
"fraternise": "fraternize",
|
624 |
+
"fraternised": "fraternized",
|
625 |
+
"fraternises": "fraternizes",
|
626 |
+
"fraternising": "fraternizing",
|
627 |
+
"fulfil": "fulfill",
|
628 |
+
"fulfilment": "fulfillment",
|
629 |
+
"fulfils": "fulfills",
|
630 |
+
"funnelled": "funneled",
|
631 |
+
"funnelling": "funneling",
|
632 |
+
"gage": "gauge",
|
633 |
+
"gaged": "gauged",
|
634 |
+
"gages": "gauges",
|
635 |
+
"gaging": "gauging",
|
636 |
+
"galvanise": "galvanize",
|
637 |
+
"galvanised": "galvanized",
|
638 |
+
"galvanises": "galvanizes",
|
639 |
+
"galvanising": "galvanizing",
|
640 |
+
"gambolled": "gamboled",
|
641 |
+
"gambolling": "gamboling",
|
642 |
+
"gaol": "jail",
|
643 |
+
"gaolbird": "jailbird",
|
644 |
+
"gaolbirds": "jailbirds",
|
645 |
+
"gaolbreak": "jailbreak",
|
646 |
+
"gaolbreaks": "jailbreaks",
|
647 |
+
"gaoled": "jailed",
|
648 |
+
"gaoler": "jailer",
|
649 |
+
"gaolers": "jailers",
|
650 |
+
"gaoling": "jailing",
|
651 |
+
"gaols": "jails",
|
652 |
+
"gasses": "gases",
|
653 |
+
"generalisation": "generalization",
|
654 |
+
"generalisations": "generalizations",
|
655 |
+
"generalise": "generalize",
|
656 |
+
"generalised": "generalized",
|
657 |
+
"generalises": "generalizes",
|
658 |
+
"generalising": "generalizing",
|
659 |
+
"ghettoise": "ghettoize",
|
660 |
+
"ghettoised": "ghettoized",
|
661 |
+
"ghettoises": "ghettoizes",
|
662 |
+
"ghettoising": "ghettoizing",
|
663 |
+
"gipsies": "gypsies",
|
664 |
+
"glamor": "glamour",
|
665 |
+
"glamorise": "glamorize",
|
666 |
+
"glamorised": "glamorized",
|
667 |
+
"glamorises": "glamorizes",
|
668 |
+
"glamorising": "glamorizing",
|
669 |
+
"globalisation": "globalization",
|
670 |
+
"globalise": "globalize",
|
671 |
+
"globalised": "globalized",
|
672 |
+
"globalises": "globalizes",
|
673 |
+
"globalising": "globalizing",
|
674 |
+
"glueing": "gluing",
|
675 |
+
"goitre": "goiter",
|
676 |
+
"goitres": "goiters",
|
677 |
+
"gonorrhoea": "gonorrhea",
|
678 |
+
"gramme": "gram",
|
679 |
+
"grammes": "grams",
|
680 |
+
"gravelled": "graveled",
|
681 |
+
"grey": "gray",
|
682 |
+
"greyed": "grayed",
|
683 |
+
"greying": "graying",
|
684 |
+
"greyish": "grayish",
|
685 |
+
"greyness": "grayness",
|
686 |
+
"greys": "grays",
|
687 |
+
"grovelled": "groveled",
|
688 |
+
"grovelling": "groveling",
|
689 |
+
"groyne": "groin",
|
690 |
+
"groynes": "groins",
|
691 |
+
"gruelling": "grueling",
|
692 |
+
"gruellingly": "gruelingly",
|
693 |
+
"gryphon": "griffin",
|
694 |
+
"gryphons": "griffins",
|
695 |
+
"gynaecological": "gynecological",
|
696 |
+
"gynaecologist": "gynecologist",
|
697 |
+
"gynaecologists": "gynecologists",
|
698 |
+
"gynaecology": "gynecology",
|
699 |
+
"haematological": "hematological",
|
700 |
+
"haematologist": "hematologist",
|
701 |
+
"haematologists": "hematologists",
|
702 |
+
"haematology": "hematology",
|
703 |
+
"haemoglobin": "hemoglobin",
|
704 |
+
"haemophilia": "hemophilia",
|
705 |
+
"haemophiliac": "hemophiliac",
|
706 |
+
"haemophiliacs": "hemophiliacs",
|
707 |
+
"haemorrhage": "hemorrhage",
|
708 |
+
"haemorrhaged": "hemorrhaged",
|
709 |
+
"haemorrhages": "hemorrhages",
|
710 |
+
"haemorrhaging": "hemorrhaging",
|
711 |
+
"haemorrhoids": "hemorrhoids",
|
712 |
+
"harbour": "harbor",
|
713 |
+
"harboured": "harbored",
|
714 |
+
"harbouring": "harboring",
|
715 |
+
"harbours": "harbors",
|
716 |
+
"harmonisation": "harmonization",
|
717 |
+
"harmonise": "harmonize",
|
718 |
+
"harmonised": "harmonized",
|
719 |
+
"harmonises": "harmonizes",
|
720 |
+
"harmonising": "harmonizing",
|
721 |
+
"homoeopath": "homeopath",
|
722 |
+
"homoeopathic": "homeopathic",
|
723 |
+
"homoeopaths": "homeopaths",
|
724 |
+
"homoeopathy": "homeopathy",
|
725 |
+
"homogenise": "homogenize",
|
726 |
+
"homogenised": "homogenized",
|
727 |
+
"homogenises": "homogenizes",
|
728 |
+
"homogenising": "homogenizing",
|
729 |
+
"honour": "honor",
|
730 |
+
"honourable": "honorable",
|
731 |
+
"honourably": "honorably",
|
732 |
+
"honoured": "honored",
|
733 |
+
"honouring": "honoring",
|
734 |
+
"honours": "honors",
|
735 |
+
"hospitalisation": "hospitalization",
|
736 |
+
"hospitalise": "hospitalize",
|
737 |
+
"hospitalised": "hospitalized",
|
738 |
+
"hospitalises": "hospitalizes",
|
739 |
+
"hospitalising": "hospitalizing",
|
740 |
+
"humanise": "humanize",
|
741 |
+
"humanised": "humanized",
|
742 |
+
"humanises": "humanizes",
|
743 |
+
"humanising": "humanizing",
|
744 |
+
"humour": "humor",
|
745 |
+
"humoured": "humored",
|
746 |
+
"humouring": "humoring",
|
747 |
+
"humourless": "humorless",
|
748 |
+
"humours": "humors",
|
749 |
+
"hybridise": "hybridize",
|
750 |
+
"hybridised": "hybridized",
|
751 |
+
"hybridises": "hybridizes",
|
752 |
+
"hybridising": "hybridizing",
|
753 |
+
"hypnotise": "hypnotize",
|
754 |
+
"hypnotised": "hypnotized",
|
755 |
+
"hypnotises": "hypnotizes",
|
756 |
+
"hypnotising": "hypnotizing",
|
757 |
+
"hypothesise": "hypothesize",
|
758 |
+
"hypothesised": "hypothesized",
|
759 |
+
"hypothesises": "hypothesizes",
|
760 |
+
"hypothesising": "hypothesizing",
|
761 |
+
"idealisation": "idealization",
|
762 |
+
"idealise": "idealize",
|
763 |
+
"idealised": "idealized",
|
764 |
+
"idealises": "idealizes",
|
765 |
+
"idealising": "idealizing",
|
766 |
+
"idolise": "idolize",
|
767 |
+
"idolised": "idolized",
|
768 |
+
"idolises": "idolizes",
|
769 |
+
"idolising": "idolizing",
|
770 |
+
"immobilisation": "immobilization",
|
771 |
+
"immobilise": "immobilize",
|
772 |
+
"immobilised": "immobilized",
|
773 |
+
"immobiliser": "immobilizer",
|
774 |
+
"immobilisers": "immobilizers",
|
775 |
+
"immobilises": "immobilizes",
|
776 |
+
"immobilising": "immobilizing",
|
777 |
+
"immortalise": "immortalize",
|
778 |
+
"immortalised": "immortalized",
|
779 |
+
"immortalises": "immortalizes",
|
780 |
+
"immortalising": "immortalizing",
|
781 |
+
"immunisation": "immunization",
|
782 |
+
"immunise": "immunize",
|
783 |
+
"immunised": "immunized",
|
784 |
+
"immunises": "immunizes",
|
785 |
+
"immunising": "immunizing",
|
786 |
+
"impanelled": "impaneled",
|
787 |
+
"impanelling": "impaneling",
|
788 |
+
"imperilled": "imperiled",
|
789 |
+
"imperilling": "imperiling",
|
790 |
+
"individualise": "individualize",
|
791 |
+
"individualised": "individualized",
|
792 |
+
"individualises": "individualizes",
|
793 |
+
"individualising": "individualizing",
|
794 |
+
"industrialise": "industrialize",
|
795 |
+
"industrialised": "industrialized",
|
796 |
+
"industrialises": "industrializes",
|
797 |
+
"industrialising": "industrializing",
|
798 |
+
"inflexion": "inflection",
|
799 |
+
"inflexions": "inflections",
|
800 |
+
"initialise": "initialize",
|
801 |
+
"initialised": "initialized",
|
802 |
+
"initialises": "initializes",
|
803 |
+
"initialising": "initializing",
|
804 |
+
"initialled": "initialed",
|
805 |
+
"initialling": "initialing",
|
806 |
+
"instal": "install",
|
807 |
+
"instalment": "installment",
|
808 |
+
"instalments": "installments",
|
809 |
+
"instals": "installs",
|
810 |
+
"instil": "instill",
|
811 |
+
"instils": "instills",
|
812 |
+
"institutionalisation": "institutionalization",
|
813 |
+
"institutionalise": "institutionalize",
|
814 |
+
"institutionalised": "institutionalized",
|
815 |
+
"institutionalises": "institutionalizes",
|
816 |
+
"institutionalising": "institutionalizing",
|
817 |
+
"intellectualise": "intellectualize",
|
818 |
+
"intellectualised": "intellectualized",
|
819 |
+
"intellectualises": "intellectualizes",
|
820 |
+
"intellectualising": "intellectualizing",
|
821 |
+
"internalisation": "internalization",
|
822 |
+
"internalise": "internalize",
|
823 |
+
"internalised": "internalized",
|
824 |
+
"internalises": "internalizes",
|
825 |
+
"internalising": "internalizing",
|
826 |
+
"internationalisation": "internationalization",
|
827 |
+
"internationalise": "internationalize",
|
828 |
+
"internationalised": "internationalized",
|
829 |
+
"internationalises": "internationalizes",
|
830 |
+
"internationalising": "internationalizing",
|
831 |
+
"ionisation": "ionization",
|
832 |
+
"ionise": "ionize",
|
833 |
+
"ionised": "ionized",
|
834 |
+
"ioniser": "ionizer",
|
835 |
+
"ionisers": "ionizers",
|
836 |
+
"ionises": "ionizes",
|
837 |
+
"ionising": "ionizing",
|
838 |
+
"italicise": "italicize",
|
839 |
+
"italicised": "italicized",
|
840 |
+
"italicises": "italicizes",
|
841 |
+
"italicising": "italicizing",
|
842 |
+
"itemise": "itemize",
|
843 |
+
"itemised": "itemized",
|
844 |
+
"itemises": "itemizes",
|
845 |
+
"itemising": "itemizing",
|
846 |
+
"jeopardise": "jeopardize",
|
847 |
+
"jeopardised": "jeopardized",
|
848 |
+
"jeopardises": "jeopardizes",
|
849 |
+
"jeopardising": "jeopardizing",
|
850 |
+
"jewelled": "jeweled",
|
851 |
+
"jeweller": "jeweler",
|
852 |
+
"jewellers": "jewelers",
|
853 |
+
"jewellery": "jewelry",
|
854 |
+
"judgement": "judgment",
|
855 |
+
"kilogramme": "kilogram",
|
856 |
+
"kilogrammes": "kilograms",
|
857 |
+
"kilometre": "kilometer",
|
858 |
+
"kilometres": "kilometers",
|
859 |
+
"labelled": "labeled",
|
860 |
+
"labelling": "labeling",
|
861 |
+
"labour": "labor",
|
862 |
+
"laboured": "labored",
|
863 |
+
"labourer": "laborer",
|
864 |
+
"labourers": "laborers",
|
865 |
+
"labouring": "laboring",
|
866 |
+
"labours": "labors",
|
867 |
+
"lacklustre": "lackluster",
|
868 |
+
"legalisation": "legalization",
|
869 |
+
"legalise": "legalize",
|
870 |
+
"legalised": "legalized",
|
871 |
+
"legalises": "legalizes",
|
872 |
+
"legalising": "legalizing",
|
873 |
+
"legitimise": "legitimize",
|
874 |
+
"legitimised": "legitimized",
|
875 |
+
"legitimises": "legitimizes",
|
876 |
+
"legitimising": "legitimizing",
|
877 |
+
"leukaemia": "leukemia",
|
878 |
+
"levelled": "leveled",
|
879 |
+
"leveller": "leveler",
|
880 |
+
"levellers": "levelers",
|
881 |
+
"levelling": "leveling",
|
882 |
+
"libelled": "libeled",
|
883 |
+
"libelling": "libeling",
|
884 |
+
"libellous": "libelous",
|
885 |
+
"liberalisation": "liberalization",
|
886 |
+
"liberalise": "liberalize",
|
887 |
+
"liberalised": "liberalized",
|
888 |
+
"liberalises": "liberalizes",
|
889 |
+
"liberalising": "liberalizing",
|
890 |
+
"licence": "license",
|
891 |
+
"licenced": "licensed",
|
892 |
+
"licences": "licenses",
|
893 |
+
"licencing": "licensing",
|
894 |
+
"likeable": "likable",
|
895 |
+
"lionisation": "lionization",
|
896 |
+
"lionise": "lionize",
|
897 |
+
"lionised": "lionized",
|
898 |
+
"lionises": "lionizes",
|
899 |
+
"lionising": "lionizing",
|
900 |
+
"liquidise": "liquidize",
|
901 |
+
"liquidised": "liquidized",
|
902 |
+
"liquidiser": "liquidizer",
|
903 |
+
"liquidisers": "liquidizers",
|
904 |
+
"liquidises": "liquidizes",
|
905 |
+
"liquidising": "liquidizing",
|
906 |
+
"litre": "liter",
|
907 |
+
"litres": "liters",
|
908 |
+
"localise": "localize",
|
909 |
+
"localised": "localized",
|
910 |
+
"localises": "localizes",
|
911 |
+
"localising": "localizing",
|
912 |
+
"louvre": "louver",
|
913 |
+
"louvred": "louvered",
|
914 |
+
"louvres": "louvers",
|
915 |
+
"lustre": "luster",
|
916 |
+
"magnetise": "magnetize",
|
917 |
+
"magnetised": "magnetized",
|
918 |
+
"magnetises": "magnetizes",
|
919 |
+
"magnetising": "magnetizing",
|
920 |
+
"manoeuvrability": "maneuverability",
|
921 |
+
"manoeuvrable": "maneuverable",
|
922 |
+
"manoeuvre": "maneuver",
|
923 |
+
"manoeuvred": "maneuvered",
|
924 |
+
"manoeuvres": "maneuvers",
|
925 |
+
"manoeuvring": "maneuvering",
|
926 |
+
"manoeuvrings": "maneuverings",
|
927 |
+
"marginalisation": "marginalization",
|
928 |
+
"marginalise": "marginalize",
|
929 |
+
"marginalised": "marginalized",
|
930 |
+
"marginalises": "marginalizes",
|
931 |
+
"marginalising": "marginalizing",
|
932 |
+
"marshalled": "marshaled",
|
933 |
+
"marshalling": "marshaling",
|
934 |
+
"marvelled": "marveled",
|
935 |
+
"marvelling": "marveling",
|
936 |
+
"marvellous": "marvelous",
|
937 |
+
"marvellously": "marvelously",
|
938 |
+
"materialisation": "materialization",
|
939 |
+
"materialise": "materialize",
|
940 |
+
"materialised": "materialized",
|
941 |
+
"materialises": "materializes",
|
942 |
+
"materialising": "materializing",
|
943 |
+
"maximisation": "maximization",
|
944 |
+
"maximise": "maximize",
|
945 |
+
"maximised": "maximized",
|
946 |
+
"maximises": "maximizes",
|
947 |
+
"maximising": "maximizing",
|
948 |
+
"meagre": "meager",
|
949 |
+
"mechanisation": "mechanization",
|
950 |
+
"mechanise": "mechanize",
|
951 |
+
"mechanised": "mechanized",
|
952 |
+
"mechanises": "mechanizes",
|
953 |
+
"mechanising": "mechanizing",
|
954 |
+
"mediaeval": "medieval",
|
955 |
+
"memorialise": "memorialize",
|
956 |
+
"memorialised": "memorialized",
|
957 |
+
"memorialises": "memorializes",
|
958 |
+
"memorialising": "memorializing",
|
959 |
+
"memorise": "memorize",
|
960 |
+
"memorised": "memorized",
|
961 |
+
"memorises": "memorizes",
|
962 |
+
"memorising": "memorizing",
|
963 |
+
"mesmerise": "mesmerize",
|
964 |
+
"mesmerised": "mesmerized",
|
965 |
+
"mesmerises": "mesmerizes",
|
966 |
+
"mesmerising": "mesmerizing",
|
967 |
+
"metabolise": "metabolize",
|
968 |
+
"metabolised": "metabolized",
|
969 |
+
"metabolises": "metabolizes",
|
970 |
+
"metabolising": "metabolizing",
|
971 |
+
"metre": "meter",
|
972 |
+
"metres": "meters",
|
973 |
+
"mhm": "hmm",
|
974 |
+
"micrometre": "micrometer",
|
975 |
+
"micrometres": "micrometers",
|
976 |
+
"militarise": "militarize",
|
977 |
+
"militarised": "militarized",
|
978 |
+
"militarises": "militarizes",
|
979 |
+
"militarising": "militarizing",
|
980 |
+
"milligramme": "milligram",
|
981 |
+
"milligrammes": "milligrams",
|
982 |
+
"millilitre": "milliliter",
|
983 |
+
"millilitres": "milliliters",
|
984 |
+
"millimetre": "millimeter",
|
985 |
+
"millimetres": "millimeters",
|
986 |
+
"miniaturisation": "miniaturization",
|
987 |
+
"miniaturise": "miniaturize",
|
988 |
+
"miniaturised": "miniaturized",
|
989 |
+
"miniaturises": "miniaturizes",
|
990 |
+
"miniaturising": "miniaturizing",
|
991 |
+
"minibusses": "minibuses",
|
992 |
+
"minimise": "minimize",
|
993 |
+
"minimised": "minimized",
|
994 |
+
"minimises": "minimizes",
|
995 |
+
"minimising": "minimizing",
|
996 |
+
"misbehaviour": "misbehavior",
|
997 |
+
"misdemeanour": "misdemeanor",
|
998 |
+
"misdemeanours": "misdemeanors",
|
999 |
+
"misspelt": "misspelled",
|
1000 |
+
"mitre": "miter",
|
1001 |
+
"mitres": "miters",
|
1002 |
+
"mm": "hmm",
|
1003 |
+
"mmm": "hmm",
|
1004 |
+
"mobilisation": "mobilization",
|
1005 |
+
"mobilise": "mobilize",
|
1006 |
+
"mobilised": "mobilized",
|
1007 |
+
"mobilises": "mobilizes",
|
1008 |
+
"mobilising": "mobilizing",
|
1009 |
+
"modelled": "modeled",
|
1010 |
+
"modeller": "modeler",
|
1011 |
+
"modellers": "modelers",
|
1012 |
+
"modelling": "modeling",
|
1013 |
+
"modernise": "modernize",
|
1014 |
+
"modernised": "modernized",
|
1015 |
+
"modernises": "modernizes",
|
1016 |
+
"modernising": "modernizing",
|
1017 |
+
"moisturise": "moisturize",
|
1018 |
+
"moisturised": "moisturized",
|
1019 |
+
"moisturiser": "moisturizer",
|
1020 |
+
"moisturisers": "moisturizers",
|
1021 |
+
"moisturises": "moisturizes",
|
1022 |
+
"moisturising": "moisturizing",
|
1023 |
+
"monologue": "monolog",
|
1024 |
+
"monologues": "monologs",
|
1025 |
+
"monopolisation": "monopolization",
|
1026 |
+
"monopolise": "monopolize",
|
1027 |
+
"monopolised": "monopolized",
|
1028 |
+
"monopolises": "monopolizes",
|
1029 |
+
"monopolising": "monopolizing",
|
1030 |
+
"moralise": "moralize",
|
1031 |
+
"moralised": "moralized",
|
1032 |
+
"moralises": "moralizes",
|
1033 |
+
"moralising": "moralizing",
|
1034 |
+
"motorised": "motorized",
|
1035 |
+
"mould": "mold",
|
1036 |
+
"moulded": "molded",
|
1037 |
+
"moulder": "molder",
|
1038 |
+
"mouldered": "moldered",
|
1039 |
+
"mouldering": "moldering",
|
1040 |
+
"moulders": "molders",
|
1041 |
+
"mouldier": "moldier",
|
1042 |
+
"mouldiest": "moldiest",
|
1043 |
+
"moulding": "molding",
|
1044 |
+
"mouldings": "moldings",
|
1045 |
+
"moulds": "molds",
|
1046 |
+
"mouldy": "moldy",
|
1047 |
+
"moult": "molt",
|
1048 |
+
"moulted": "molted",
|
1049 |
+
"moulting": "molting",
|
1050 |
+
"moults": "molts",
|
1051 |
+
"moustache": "mustache",
|
1052 |
+
"moustached": "mustached",
|
1053 |
+
"moustaches": "mustaches",
|
1054 |
+
"moustachioed": "mustachioed",
|
1055 |
+
"multicoloured": "multicolored",
|
1056 |
+
"nationalisation": "nationalization",
|
1057 |
+
"nationalisations": "nationalizations",
|
1058 |
+
"nationalise": "nationalize",
|
1059 |
+
"nationalised": "nationalized",
|
1060 |
+
"nationalises": "nationalizes",
|
1061 |
+
"nationalising": "nationalizing",
|
1062 |
+
"naturalisation": "naturalization",
|
1063 |
+
"naturalise": "naturalize",
|
1064 |
+
"naturalised": "naturalized",
|
1065 |
+
"naturalises": "naturalizes",
|
1066 |
+
"naturalising": "naturalizing",
|
1067 |
+
"neighbour": "neighbor",
|
1068 |
+
"neighbourhood": "neighborhood",
|
1069 |
+
"neighbourhoods": "neighborhoods",
|
1070 |
+
"neighbouring": "neighboring",
|
1071 |
+
"neighbourliness": "neighborliness",
|
1072 |
+
"neighbourly": "neighborly",
|
1073 |
+
"neighbours": "neighbors",
|
1074 |
+
"neutralisation": "neutralization",
|
1075 |
+
"neutralise": "neutralize",
|
1076 |
+
"neutralised": "neutralized",
|
1077 |
+
"neutralises": "neutralizes",
|
1078 |
+
"neutralising": "neutralizing",
|
1079 |
+
"normalisation": "normalization",
|
1080 |
+
"normalise": "normalize",
|
1081 |
+
"normalised": "normalized",
|
1082 |
+
"normalises": "normalizes",
|
1083 |
+
"normalising": "normalizing",
|
1084 |
+
"odour": "odor",
|
1085 |
+
"odourless": "odorless",
|
1086 |
+
"odours": "odors",
|
1087 |
+
"oesophagus": "esophagus",
|
1088 |
+
"oesophaguses": "esophaguses",
|
1089 |
+
"oestrogen": "estrogen",
|
1090 |
+
"offence": "offense",
|
1091 |
+
"offences": "offenses",
|
1092 |
+
"omelette": "omelet",
|
1093 |
+
"omelettes": "omelets",
|
1094 |
+
"optimise": "optimize",
|
1095 |
+
"optimised": "optimized",
|
1096 |
+
"optimises": "optimizes",
|
1097 |
+
"optimising": "optimizing",
|
1098 |
+
"organisation": "organization",
|
1099 |
+
"organisational": "organizational",
|
1100 |
+
"organisations": "organizations",
|
1101 |
+
"organise": "organize",
|
1102 |
+
"organised": "organized",
|
1103 |
+
"organiser": "organizer",
|
1104 |
+
"organisers": "organizers",
|
1105 |
+
"organises": "organizes",
|
1106 |
+
"organising": "organizing",
|
1107 |
+
"orthopaedic": "orthopedic",
|
1108 |
+
"orthopaedics": "orthopedics",
|
1109 |
+
"ostracise": "ostracize",
|
1110 |
+
"ostracised": "ostracized",
|
1111 |
+
"ostracises": "ostracizes",
|
1112 |
+
"ostracising": "ostracizing",
|
1113 |
+
"outmanoeuvre": "outmaneuver",
|
1114 |
+
"outmanoeuvred": "outmaneuvered",
|
1115 |
+
"outmanoeuvres": "outmaneuvers",
|
1116 |
+
"outmanoeuvring": "outmaneuvering",
|
1117 |
+
"overemphasise": "overemphasize",
|
1118 |
+
"overemphasised": "overemphasized",
|
1119 |
+
"overemphasises": "overemphasizes",
|
1120 |
+
"overemphasising": "overemphasizing",
|
1121 |
+
"oxidisation": "oxidization",
|
1122 |
+
"oxidise": "oxidize",
|
1123 |
+
"oxidised": "oxidized",
|
1124 |
+
"oxidises": "oxidizes",
|
1125 |
+
"oxidising": "oxidizing",
|
1126 |
+
"paederast": "pederast",
|
1127 |
+
"paederasts": "pederasts",
|
1128 |
+
"paediatric": "pediatric",
|
1129 |
+
"paediatrician": "pediatrician",
|
1130 |
+
"paediatricians": "pediatricians",
|
1131 |
+
"paediatrics": "pediatrics",
|
1132 |
+
"paedophile": "pedophile",
|
1133 |
+
"paedophiles": "pedophiles",
|
1134 |
+
"paedophilia": "pedophilia",
|
1135 |
+
"palaeolithic": "paleolithic",
|
1136 |
+
"palaeontologist": "paleontologist",
|
1137 |
+
"palaeontologists": "paleontologists",
|
1138 |
+
"palaeontology": "paleontology",
|
1139 |
+
"panelled": "paneled",
|
1140 |
+
"panelling": "paneling",
|
1141 |
+
"panellist": "panelist",
|
1142 |
+
"panellists": "panelists",
|
1143 |
+
"paralyse": "paralyze",
|
1144 |
+
"paralysed": "paralyzed",
|
1145 |
+
"paralyses": "paralyzes",
|
1146 |
+
"paralysing": "paralyzing",
|
1147 |
+
"parcelled": "parceled",
|
1148 |
+
"parcelling": "parceling",
|
1149 |
+
"parlour": "parlor",
|
1150 |
+
"parlours": "parlors",
|
1151 |
+
"particularise": "particularize",
|
1152 |
+
"particularised": "particularized",
|
1153 |
+
"particularises": "particularizes",
|
1154 |
+
"particularising": "particularizing",
|
1155 |
+
"passivisation": "passivization",
|
1156 |
+
"passivise": "passivize",
|
1157 |
+
"passivised": "passivized",
|
1158 |
+
"passivises": "passivizes",
|
1159 |
+
"passivising": "passivizing",
|
1160 |
+
"pasteurisation": "pasteurization",
|
1161 |
+
"pasteurise": "pasteurize",
|
1162 |
+
"pasteurised": "pasteurized",
|
1163 |
+
"pasteurises": "pasteurizes",
|
1164 |
+
"pasteurising": "pasteurizing",
|
1165 |
+
"patronise": "patronize",
|
1166 |
+
"patronised": "patronized",
|
1167 |
+
"patronises": "patronizes",
|
1168 |
+
"patronising": "patronizing",
|
1169 |
+
"patronisingly": "patronizingly",
|
1170 |
+
"pedalled": "pedaled",
|
1171 |
+
"pedalling": "pedaling",
|
1172 |
+
"pedestrianisation": "pedestrianization",
|
1173 |
+
"pedestrianise": "pedestrianize",
|
1174 |
+
"pedestrianised": "pedestrianized",
|
1175 |
+
"pedestrianises": "pedestrianizes",
|
1176 |
+
"pedestrianising": "pedestrianizing",
|
1177 |
+
"penalise": "penalize",
|
1178 |
+
"penalised": "penalized",
|
1179 |
+
"penalises": "penalizes",
|
1180 |
+
"penalising": "penalizing",
|
1181 |
+
"pencilled": "penciled",
|
1182 |
+
"pencilling": "penciling",
|
1183 |
+
"personalise": "personalize",
|
1184 |
+
"personalised": "personalized",
|
1185 |
+
"personalises": "personalizes",
|
1186 |
+
"personalising": "personalizing",
|
1187 |
+
"pharmacopoeia": "pharmacopeia",
|
1188 |
+
"pharmacopoeias": "pharmacopeias",
|
1189 |
+
"philosophise": "philosophize",
|
1190 |
+
"philosophised": "philosophized",
|
1191 |
+
"philosophises": "philosophizes",
|
1192 |
+
"philosophising": "philosophizing",
|
1193 |
+
"philtre": "filter",
|
1194 |
+
"philtres": "filters",
|
1195 |
+
"phoney": "phony",
|
1196 |
+
"plagiarise": "plagiarize",
|
1197 |
+
"plagiarised": "plagiarized",
|
1198 |
+
"plagiarises": "plagiarizes",
|
1199 |
+
"plagiarising": "plagiarizing",
|
1200 |
+
"plough": "plow",
|
1201 |
+
"ploughed": "plowed",
|
1202 |
+
"ploughing": "plowing",
|
1203 |
+
"ploughman": "plowman",
|
1204 |
+
"ploughmen": "plowmen",
|
1205 |
+
"ploughs": "plows",
|
1206 |
+
"ploughshare": "plowshare",
|
1207 |
+
"ploughshares": "plowshares",
|
1208 |
+
"polarisation": "polarization",
|
1209 |
+
"polarise": "polarize",
|
1210 |
+
"polarised": "polarized",
|
1211 |
+
"polarises": "polarizes",
|
1212 |
+
"polarising": "polarizing",
|
1213 |
+
"politicisation": "politicization",
|
1214 |
+
"politicise": "politicize",
|
1215 |
+
"politicised": "politicized",
|
1216 |
+
"politicises": "politicizes",
|
1217 |
+
"politicising": "politicizing",
|
1218 |
+
"popularisation": "popularization",
|
1219 |
+
"popularise": "popularize",
|
1220 |
+
"popularised": "popularized",
|
1221 |
+
"popularises": "popularizes",
|
1222 |
+
"popularising": "popularizing",
|
1223 |
+
"pouffe": "pouf",
|
1224 |
+
"pouffes": "poufs",
|
1225 |
+
"practise": "practice",
|
1226 |
+
"practised": "practiced",
|
1227 |
+
"practises": "practices",
|
1228 |
+
"practising": "practicing",
|
1229 |
+
"praesidium": "presidium",
|
1230 |
+
"praesidiums": "presidiums",
|
1231 |
+
"pressurisation": "pressurization",
|
1232 |
+
"pressurise": "pressurize",
|
1233 |
+
"pressurised": "pressurized",
|
1234 |
+
"pressurises": "pressurizes",
|
1235 |
+
"pressurising": "pressurizing",
|
1236 |
+
"pretence": "pretense",
|
1237 |
+
"pretences": "pretenses",
|
1238 |
+
"primaeval": "primeval",
|
1239 |
+
"prioritisation": "prioritization",
|
1240 |
+
"prioritise": "prioritize",
|
1241 |
+
"prioritised": "prioritized",
|
1242 |
+
"prioritises": "prioritizes",
|
1243 |
+
"prioritising": "prioritizing",
|
1244 |
+
"privatisation": "privatization",
|
1245 |
+
"privatisations": "privatizations",
|
1246 |
+
"privatise": "privatize",
|
1247 |
+
"privatised": "privatized",
|
1248 |
+
"privatises": "privatizes",
|
1249 |
+
"privatising": "privatizing",
|
1250 |
+
"professionalisation": "professionalization",
|
1251 |
+
"professionalise": "professionalize",
|
1252 |
+
"professionalised": "professionalized",
|
1253 |
+
"professionalises": "professionalizes",
|
1254 |
+
"professionalising": "professionalizing",
|
1255 |
+
"programme": "program",
|
1256 |
+
"programmes": "programs",
|
1257 |
+
"prologue": "prolog",
|
1258 |
+
"prologues": "prologs",
|
1259 |
+
"propagandise": "propagandize",
|
1260 |
+
"propagandised": "propagandized",
|
1261 |
+
"propagandises": "propagandizes",
|
1262 |
+
"propagandising": "propagandizing",
|
1263 |
+
"proselytise": "proselytize",
|
1264 |
+
"proselytised": "proselytized",
|
1265 |
+
"proselytiser": "proselytizer",
|
1266 |
+
"proselytisers": "proselytizers",
|
1267 |
+
"proselytises": "proselytizes",
|
1268 |
+
"proselytising": "proselytizing",
|
1269 |
+
"psychoanalyse": "psychoanalyze",
|
1270 |
+
"psychoanalysed": "psychoanalyzed",
|
1271 |
+
"psychoanalyses": "psychoanalyzes",
|
1272 |
+
"psychoanalysing": "psychoanalyzing",
|
1273 |
+
"publicise": "publicize",
|
1274 |
+
"publicised": "publicized",
|
1275 |
+
"publicises": "publicizes",
|
1276 |
+
"publicising": "publicizing",
|
1277 |
+
"pulverisation": "pulverization",
|
1278 |
+
"pulverise": "pulverize",
|
1279 |
+
"pulverised": "pulverized",
|
1280 |
+
"pulverises": "pulverizes",
|
1281 |
+
"pulverising": "pulverizing",
|
1282 |
+
"pummelled": "pummel",
|
1283 |
+
"pummelling": "pummeled",
|
1284 |
+
"pyjama": "pajama",
|
1285 |
+
"pyjamas": "pajamas",
|
1286 |
+
"pzazz": "pizzazz",
|
1287 |
+
"quarrelled": "quarreled",
|
1288 |
+
"quarrelling": "quarreling",
|
1289 |
+
"radicalise": "radicalize",
|
1290 |
+
"radicalised": "radicalized",
|
1291 |
+
"radicalises": "radicalizes",
|
1292 |
+
"radicalising": "radicalizing",
|
1293 |
+
"rancour": "rancor",
|
1294 |
+
"randomise": "randomize",
|
1295 |
+
"randomised": "randomized",
|
1296 |
+
"randomises": "randomizes",
|
1297 |
+
"randomising": "randomizing",
|
1298 |
+
"rationalisation": "rationalization",
|
1299 |
+
"rationalisations": "rationalizations",
|
1300 |
+
"rationalise": "rationalize",
|
1301 |
+
"rationalised": "rationalized",
|
1302 |
+
"rationalises": "rationalizes",
|
1303 |
+
"rationalising": "rationalizing",
|
1304 |
+
"ravelled": "raveled",
|
1305 |
+
"ravelling": "raveling",
|
1306 |
+
"realisable": "realizable",
|
1307 |
+
"realisation": "realization",
|
1308 |
+
"realisations": "realizations",
|
1309 |
+
"realise": "realize",
|
1310 |
+
"realised": "realized",
|
1311 |
+
"realises": "realizes",
|
1312 |
+
"realising": "realizing",
|
1313 |
+
"recognisable": "recognizable",
|
1314 |
+
"recognisably": "recognizably",
|
1315 |
+
"recognisance": "recognizance",
|
1316 |
+
"recognise": "recognize",
|
1317 |
+
"recognised": "recognized",
|
1318 |
+
"recognises": "recognizes",
|
1319 |
+
"recognising": "recognizing",
|
1320 |
+
"reconnoitre": "reconnoiter",
|
1321 |
+
"reconnoitred": "reconnoitered",
|
1322 |
+
"reconnoitres": "reconnoiters",
|
1323 |
+
"reconnoitring": "reconnoitering",
|
1324 |
+
"refuelled": "refueled",
|
1325 |
+
"refuelling": "refueling",
|
1326 |
+
"regularisation": "regularization",
|
1327 |
+
"regularise": "regularize",
|
1328 |
+
"regularised": "regularized",
|
1329 |
+
"regularises": "regularizes",
|
1330 |
+
"regularising": "regularizing",
|
1331 |
+
"remodelled": "remodeled",
|
1332 |
+
"remodelling": "remodeling",
|
1333 |
+
"remould": "remold",
|
1334 |
+
"remoulded": "remolded",
|
1335 |
+
"remoulding": "remolding",
|
1336 |
+
"remoulds": "remolds",
|
1337 |
+
"reorganisation": "reorganization",
|
1338 |
+
"reorganisations": "reorganizations",
|
1339 |
+
"reorganise": "reorganize",
|
1340 |
+
"reorganised": "reorganized",
|
1341 |
+
"reorganises": "reorganizes",
|
1342 |
+
"reorganising": "reorganizing",
|
1343 |
+
"revelled": "reveled",
|
1344 |
+
"reveller": "reveler",
|
1345 |
+
"revellers": "revelers",
|
1346 |
+
"revelling": "reveling",
|
1347 |
+
"revitalise": "revitalize",
|
1348 |
+
"revitalised": "revitalized",
|
1349 |
+
"revitalises": "revitalizes",
|
1350 |
+
"revitalising": "revitalizing",
|
1351 |
+
"revolutionise": "revolutionize",
|
1352 |
+
"revolutionised": "revolutionized",
|
1353 |
+
"revolutionises": "revolutionizes",
|
1354 |
+
"revolutionising": "revolutionizing",
|
1355 |
+
"rhapsodise": "rhapsodize",
|
1356 |
+
"rhapsodised": "rhapsodized",
|
1357 |
+
"rhapsodises": "rhapsodizes",
|
1358 |
+
"rhapsodising": "rhapsodizing",
|
1359 |
+
"rigour": "rigor",
|
1360 |
+
"rigours": "rigors",
|
1361 |
+
"ritualised": "ritualized",
|
1362 |
+
"rivalled": "rivaled",
|
1363 |
+
"rivalling": "rivaling",
|
1364 |
+
"romanticise": "romanticize",
|
1365 |
+
"romanticised": "romanticized",
|
1366 |
+
"romanticises": "romanticizes",
|
1367 |
+
"romanticising": "romanticizing",
|
1368 |
+
"rumour": "rumor",
|
1369 |
+
"rumoured": "rumored",
|
1370 |
+
"rumours": "rumors",
|
1371 |
+
"sabre": "saber",
|
1372 |
+
"sabres": "sabers",
|
1373 |
+
"saltpetre": "saltpeter",
|
1374 |
+
"sanitise": "sanitize",
|
1375 |
+
"sanitised": "sanitized",
|
1376 |
+
"sanitises": "sanitizes",
|
1377 |
+
"sanitising": "sanitizing",
|
1378 |
+
"satirise": "satirize",
|
1379 |
+
"satirised": "satirized",
|
1380 |
+
"satirises": "satirizes",
|
1381 |
+
"satirising": "satirizing",
|
1382 |
+
"saviour": "savior",
|
1383 |
+
"saviours": "saviors",
|
1384 |
+
"savour": "savor",
|
1385 |
+
"savoured": "savored",
|
1386 |
+
"savouries": "savories",
|
1387 |
+
"savouring": "savoring",
|
1388 |
+
"savours": "savors",
|
1389 |
+
"savoury": "savory",
|
1390 |
+
"scandalise": "scandalize",
|
1391 |
+
"scandalised": "scandalized",
|
1392 |
+
"scandalises": "scandalizes",
|
1393 |
+
"scandalising": "scandalizing",
|
1394 |
+
"sceptic": "skeptic",
|
1395 |
+
"sceptical": "skeptical",
|
1396 |
+
"sceptically": "skeptically",
|
1397 |
+
"scepticism": "skepticism",
|
1398 |
+
"sceptics": "skeptics",
|
1399 |
+
"sceptre": "scepter",
|
1400 |
+
"sceptres": "scepters",
|
1401 |
+
"scrutinise": "scrutinize",
|
1402 |
+
"scrutinised": "scrutinized",
|
1403 |
+
"scrutinises": "scrutinizes",
|
1404 |
+
"scrutinising": "scrutinizing",
|
1405 |
+
"secularisation": "secularization",
|
1406 |
+
"secularise": "secularize",
|
1407 |
+
"secularised": "secularized",
|
1408 |
+
"secularises": "secularizes",
|
1409 |
+
"secularising": "secularizing",
|
1410 |
+
"sensationalise": "sensationalize",
|
1411 |
+
"sensationalised": "sensationalized",
|
1412 |
+
"sensationalises": "sensationalizes",
|
1413 |
+
"sensationalising": "sensationalizing",
|
1414 |
+
"sensitise": "sensitize",
|
1415 |
+
"sensitised": "sensitized",
|
1416 |
+
"sensitises": "sensitizes",
|
1417 |
+
"sensitising": "sensitizing",
|
1418 |
+
"sentimentalise": "sentimentalize",
|
1419 |
+
"sentimentalised": "sentimentalized",
|
1420 |
+
"sentimentalises": "sentimentalizes",
|
1421 |
+
"sentimentalising": "sentimentalizing",
|
1422 |
+
"sepulchre": "sepulcher",
|
1423 |
+
"sepulchres": "sepulchers",
|
1424 |
+
"serialisation": "serialization",
|
1425 |
+
"serialisations": "serializations",
|
1426 |
+
"serialise": "serialize",
|
1427 |
+
"serialised": "serialized",
|
1428 |
+
"serialises": "serializes",
|
1429 |
+
"serialising": "serializing",
|
1430 |
+
"sermonise": "sermonize",
|
1431 |
+
"sermonised": "sermonized",
|
1432 |
+
"sermonises": "sermonizes",
|
1433 |
+
"sermonising": "sermonizing",
|
1434 |
+
"sheikh": "sheik",
|
1435 |
+
"shovelled": "shoveled",
|
1436 |
+
"shovelling": "shoveling",
|
1437 |
+
"shrivelled": "shriveled",
|
1438 |
+
"shrivelling": "shriveling",
|
1439 |
+
"signalise": "signalize",
|
1440 |
+
"signalised": "signalized",
|
1441 |
+
"signalises": "signalizes",
|
1442 |
+
"signalising": "signalizing",
|
1443 |
+
"signalled": "signaled",
|
1444 |
+
"signalling": "signaling",
|
1445 |
+
"smoulder": "smolder",
|
1446 |
+
"smouldered": "smoldered",
|
1447 |
+
"smouldering": "smoldering",
|
1448 |
+
"smoulders": "smolders",
|
1449 |
+
"snivelled": "sniveled",
|
1450 |
+
"snivelling": "sniveling",
|
1451 |
+
"snorkelled": "snorkeled",
|
1452 |
+
"snorkelling": "snorkeling",
|
1453 |
+
"snowplough": "snowplow",
|
1454 |
+
"snowploughs": "snowplow",
|
1455 |
+
"socialisation": "socialization",
|
1456 |
+
"socialise": "socialize",
|
1457 |
+
"socialised": "socialized",
|
1458 |
+
"socialises": "socializes",
|
1459 |
+
"socialising": "socializing",
|
1460 |
+
"sodomise": "sodomize",
|
1461 |
+
"sodomised": "sodomized",
|
1462 |
+
"sodomises": "sodomizes",
|
1463 |
+
"sodomising": "sodomizing",
|
1464 |
+
"solemnise": "solemnize",
|
1465 |
+
"solemnised": "solemnized",
|
1466 |
+
"solemnises": "solemnizes",
|
1467 |
+
"solemnising": "solemnizing",
|
1468 |
+
"sombre": "somber",
|
1469 |
+
"specialisation": "specialization",
|
1470 |
+
"specialisations": "specializations",
|
1471 |
+
"specialise": "specialize",
|
1472 |
+
"specialised": "specialized",
|
1473 |
+
"specialises": "specializes",
|
1474 |
+
"specialising": "specializing",
|
1475 |
+
"spectre": "specter",
|
1476 |
+
"spectres": "specters",
|
1477 |
+
"spiralled": "spiraled",
|
1478 |
+
"spiralling": "spiraling",
|
1479 |
+
"splendour": "splendor",
|
1480 |
+
"splendours": "splendors",
|
1481 |
+
"squirrelled": "squirreled",
|
1482 |
+
"squirrelling": "squirreling",
|
1483 |
+
"stabilisation": "stabilization",
|
1484 |
+
"stabilise": "stabilize",
|
1485 |
+
"stabilised": "stabilized",
|
1486 |
+
"stabiliser": "stabilizer",
|
1487 |
+
"stabilisers": "stabilizers",
|
1488 |
+
"stabilises": "stabilizes",
|
1489 |
+
"stabilising": "stabilizing",
|
1490 |
+
"standardisation": "standardization",
|
1491 |
+
"standardise": "standardize",
|
1492 |
+
"standardised": "standardized",
|
1493 |
+
"standardises": "standardizes",
|
1494 |
+
"standardising": "standardizing",
|
1495 |
+
"stencilled": "stenciled",
|
1496 |
+
"stencilling": "stenciling",
|
1497 |
+
"sterilisation": "sterilization",
|
1498 |
+
"sterilisations": "sterilizations",
|
1499 |
+
"sterilise": "sterilize",
|
1500 |
+
"sterilised": "sterilized",
|
1501 |
+
"steriliser": "sterilizer",
|
1502 |
+
"sterilisers": "sterilizers",
|
1503 |
+
"sterilises": "sterilizes",
|
1504 |
+
"sterilising": "sterilizing",
|
1505 |
+
"stigmatisation": "stigmatization",
|
1506 |
+
"stigmatise": "stigmatize",
|
1507 |
+
"stigmatised": "stigmatized",
|
1508 |
+
"stigmatises": "stigmatizes",
|
1509 |
+
"stigmatising": "stigmatizing",
|
1510 |
+
"storey": "story",
|
1511 |
+
"storeys": "stories",
|
1512 |
+
"subsidisation": "subsidization",
|
1513 |
+
"subsidise": "subsidize",
|
1514 |
+
"subsidised": "subsidized",
|
1515 |
+
"subsidiser": "subsidizer",
|
1516 |
+
"subsidisers": "subsidizers",
|
1517 |
+
"subsidises": "subsidizes",
|
1518 |
+
"subsidising": "subsidizing",
|
1519 |
+
"succour": "succor",
|
1520 |
+
"succoured": "succored",
|
1521 |
+
"succouring": "succoring",
|
1522 |
+
"succours": "succors",
|
1523 |
+
"sulphate": "sulfate",
|
1524 |
+
"sulphates": "sulfates",
|
1525 |
+
"sulphide": "sulfide",
|
1526 |
+
"sulphides": "sulfides",
|
1527 |
+
"sulphur": "sulfur",
|
1528 |
+
"sulphurous": "sulfurous",
|
1529 |
+
"summarise": "summarize",
|
1530 |
+
"summarised": "summarized",
|
1531 |
+
"summarises": "summarizes",
|
1532 |
+
"summarising": "summarizing",
|
1533 |
+
"swivelled": "swiveled",
|
1534 |
+
"swivelling": "swiveling",
|
1535 |
+
"symbolise": "symbolize",
|
1536 |
+
"symbolised": "symbolized",
|
1537 |
+
"symbolises": "symbolizes",
|
1538 |
+
"symbolising": "symbolizing",
|
1539 |
+
"sympathise": "sympathize",
|
1540 |
+
"sympathised": "sympathized",
|
1541 |
+
"sympathiser": "sympathizer",
|
1542 |
+
"sympathisers": "sympathizers",
|
1543 |
+
"sympathises": "sympathizes",
|
1544 |
+
"sympathising": "sympathizing",
|
1545 |
+
"synchronisation": "synchronization",
|
1546 |
+
"synchronise": "synchronize",
|
1547 |
+
"synchronised": "synchronized",
|
1548 |
+
"synchronises": "synchronizes",
|
1549 |
+
"synchronising": "synchronizing",
|
1550 |
+
"synthesise": "synthesize",
|
1551 |
+
"synthesised": "synthesized",
|
1552 |
+
"synthesiser": "synthesizer",
|
1553 |
+
"synthesisers": "synthesizers",
|
1554 |
+
"synthesises": "synthesizes",
|
1555 |
+
"synthesising": "synthesizing",
|
1556 |
+
"syphon": "siphon",
|
1557 |
+
"syphoned": "siphoned",
|
1558 |
+
"syphoning": "siphoning",
|
1559 |
+
"syphons": "siphons",
|
1560 |
+
"systematisation": "systematization",
|
1561 |
+
"systematise": "systematize",
|
1562 |
+
"systematised": "systematized",
|
1563 |
+
"systematises": "systematizes",
|
1564 |
+
"systematising": "systematizing",
|
1565 |
+
"tantalise": "tantalize",
|
1566 |
+
"tantalised": "tantalized",
|
1567 |
+
"tantalises": "tantalizes",
|
1568 |
+
"tantalising": "tantalizing",
|
1569 |
+
"tantalisingly": "tantalizingly",
|
1570 |
+
"tasselled": "tasseled",
|
1571 |
+
"technicolour": "technicolor",
|
1572 |
+
"temporise": "temporize",
|
1573 |
+
"temporised": "temporized",
|
1574 |
+
"temporises": "temporizes",
|
1575 |
+
"temporising": "temporizing",
|
1576 |
+
"tenderise": "tenderize",
|
1577 |
+
"tenderised": "tenderized",
|
1578 |
+
"tenderises": "tenderizes",
|
1579 |
+
"tenderising": "tenderizing",
|
1580 |
+
"terrorise": "terrorize",
|
1581 |
+
"terrorised": "terrorized",
|
1582 |
+
"terrorises": "terrorizes",
|
1583 |
+
"terrorising": "terrorizing",
|
1584 |
+
"theatre": "theater",
|
1585 |
+
"theatregoer": "theatergoer",
|
1586 |
+
"theatregoers": "theatergoers",
|
1587 |
+
"theatres": "theaters",
|
1588 |
+
"theorise": "theorize",
|
1589 |
+
"theorised": "theorized",
|
1590 |
+
"theorises": "theorizes",
|
1591 |
+
"theorising": "theorizing",
|
1592 |
+
"tonne": "ton",
|
1593 |
+
"tonnes": "tons",
|
1594 |
+
"towelled": "toweled",
|
1595 |
+
"towelling": "toweling",
|
1596 |
+
"toxaemia": "toxemia",
|
1597 |
+
"tranquillise": "tranquilize",
|
1598 |
+
"tranquillised": "tranquilized",
|
1599 |
+
"tranquilliser": "tranquilizer",
|
1600 |
+
"tranquillisers": "tranquilizers",
|
1601 |
+
"tranquillises": "tranquilizes",
|
1602 |
+
"tranquillising": "tranquilizing",
|
1603 |
+
"tranquillity": "tranquility",
|
1604 |
+
"tranquillize": "tranquilize",
|
1605 |
+
"tranquillized": "tranquilized",
|
1606 |
+
"tranquillizer": "tranquilizer",
|
1607 |
+
"tranquillizers": "tranquilizers",
|
1608 |
+
"tranquillizes": "tranquilizes",
|
1609 |
+
"tranquillizing": "tranquilizing",
|
1610 |
+
"tranquilly": "tranquility",
|
1611 |
+
"transistorised": "transistorized",
|
1612 |
+
"traumatise": "traumatize",
|
1613 |
+
"traumatised": "traumatized",
|
1614 |
+
"traumatises": "traumatizes",
|
1615 |
+
"traumatising": "traumatizing",
|
1616 |
+
"travelled": "traveled",
|
1617 |
+
"traveller": "traveler",
|
1618 |
+
"travellers": "travelers",
|
1619 |
+
"travelling": "traveling",
|
1620 |
+
"travelog": "travelogue",
|
1621 |
+
"travelogs": "travelogues",
|
1622 |
+
"trialled": "trialed",
|
1623 |
+
"trialling": "trialing",
|
1624 |
+
"tricolour": "tricolor",
|
1625 |
+
"tricolours": "tricolors",
|
1626 |
+
"trivialise": "trivialize",
|
1627 |
+
"trivialised": "trivialized",
|
1628 |
+
"trivialises": "trivializes",
|
1629 |
+
"trivialising": "trivializing",
|
1630 |
+
"tumour": "tumor",
|
1631 |
+
"tumours": "tumors",
|
1632 |
+
"tunnelled": "tunneled",
|
1633 |
+
"tunnelling": "tunneling",
|
1634 |
+
"tyrannise": "tyrannize",
|
1635 |
+
"tyrannised": "tyrannized",
|
1636 |
+
"tyrannises": "tyrannizes",
|
1637 |
+
"tyrannising": "tyrannizing",
|
1638 |
+
"tyre": "tire",
|
1639 |
+
"tyres": "tires",
|
1640 |
+
"unauthorised": "unauthorized",
|
1641 |
+
"uncivilised": "uncivilized",
|
1642 |
+
"underutilised": "underutilized",
|
1643 |
+
"unequalled": "unequaled",
|
1644 |
+
"unfavourable": "unfavorable",
|
1645 |
+
"unfavourably": "unfavorably",
|
1646 |
+
"unionisation": "unionization",
|
1647 |
+
"unionise": "unionize",
|
1648 |
+
"unionised": "unionized",
|
1649 |
+
"unionises": "unionizes",
|
1650 |
+
"unionising": "unionizing",
|
1651 |
+
"unorganised": "unorganized",
|
1652 |
+
"unravelled": "unraveled",
|
1653 |
+
"unravelling": "unraveling",
|
1654 |
+
"unrecognisable": "unrecognizable",
|
1655 |
+
"unrecognised": "unrecognized",
|
1656 |
+
"unrivalled": "unrivaled",
|
1657 |
+
"unsavoury": "unsavory",
|
1658 |
+
"untrammelled": "untrammeled",
|
1659 |
+
"urbanisation": "urbanization",
|
1660 |
+
"urbanise": "urbanize",
|
1661 |
+
"urbanised": "urbanized",
|
1662 |
+
"urbanises": "urbanizes",
|
1663 |
+
"urbanising": "urbanizing",
|
1664 |
+
"utilisable": "utilizable",
|
1665 |
+
"utilisation": "utilization",
|
1666 |
+
"utilise": "utilize",
|
1667 |
+
"utilised": "utilized",
|
1668 |
+
"utilises": "utilizes",
|
1669 |
+
"utilising": "utilizing",
|
1670 |
+
"valour": "valor",
|
1671 |
+
"vandalise": "vandalize",
|
1672 |
+
"vandalised": "vandalized",
|
1673 |
+
"vandalises": "vandalizes",
|
1674 |
+
"vandalising": "vandalizing",
|
1675 |
+
"vaporisation": "vaporization",
|
1676 |
+
"vaporise": "vaporize",
|
1677 |
+
"vaporised": "vaporized",
|
1678 |
+
"vaporises": "vaporizes",
|
1679 |
+
"vaporising": "vaporizing",
|
1680 |
+
"vapour": "vapor",
|
1681 |
+
"vapours": "vapors",
|
1682 |
+
"verbalise": "verbalize",
|
1683 |
+
"verbalised": "verbalized",
|
1684 |
+
"verbalises": "verbalizes",
|
1685 |
+
"verbalising": "verbalizing",
|
1686 |
+
"victimisation": "victimization",
|
1687 |
+
"victimise": "victimize",
|
1688 |
+
"victimised": "victimized",
|
1689 |
+
"victimises": "victimizes",
|
1690 |
+
"victimising": "victimizing",
|
1691 |
+
"videodisc": "videodisk",
|
1692 |
+
"videodiscs": "videodisks",
|
1693 |
+
"vigour": "vigor",
|
1694 |
+
"visualisation": "visualization",
|
1695 |
+
"visualisations": "visualizations",
|
1696 |
+
"visualise": "visualize",
|
1697 |
+
"visualised": "visualized",
|
1698 |
+
"visualises": "visualizes",
|
1699 |
+
"visualising": "visualizing",
|
1700 |
+
"vocalisation": "vocalization",
|
1701 |
+
"vocalisations": "vocalizations",
|
1702 |
+
"vocalise": "vocalize",
|
1703 |
+
"vocalised": "vocalized",
|
1704 |
+
"vocalises": "vocalizes",
|
1705 |
+
"vocalising": "vocalizing",
|
1706 |
+
"vulcanised": "vulcanized",
|
1707 |
+
"vulgarisation": "vulgarization",
|
1708 |
+
"vulgarise": "vulgarize",
|
1709 |
+
"vulgarised": "vulgarized",
|
1710 |
+
"vulgarises": "vulgarizes",
|
1711 |
+
"vulgarising": "vulgarizing",
|
1712 |
+
"waggon": "wagon",
|
1713 |
+
"waggons": "wagons",
|
1714 |
+
"watercolour": "watercolor",
|
1715 |
+
"watercolours": "watercolors",
|
1716 |
+
"weaselled": "weaseled",
|
1717 |
+
"weaselling": "weaseling",
|
1718 |
+
"westernisation": "westernization",
|
1719 |
+
"westernise": "westernize",
|
1720 |
+
"westernised": "westernized",
|
1721 |
+
"westernises": "westernizes",
|
1722 |
+
"westernising": "westernizing",
|
1723 |
+
"womanise": "womanize",
|
1724 |
+
"womanised": "womanized",
|
1725 |
+
"womaniser": "womanizer",
|
1726 |
+
"womanisers": "womanizers",
|
1727 |
+
"womanises": "womanizes",
|
1728 |
+
"womanising": "womanizing",
|
1729 |
+
"woollen": "woolen",
|
1730 |
+
"woollens": "woolens",
|
1731 |
+
"woollies": "woolies",
|
1732 |
+
"woolly": "wooly",
|
1733 |
+
"worshipped": "worshiped",
|
1734 |
+
"worshipper": "worshiper",
|
1735 |
+
"worshipping": "worshiping",
|
1736 |
+
"yodelled": "yodeled",
|
1737 |
+
"yodelling": "yodeling",
|
1738 |
+
"yoghourt": "yogurt",
|
1739 |
+
"yoghourts": "yogurts",
|
1740 |
+
"yoghurt": "yogurt",
|
1741 |
+
"yoghurts": "yogurts"
|
1742 |
+
}
|
distil-large-v3-init/preprocessor_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"chunk_length": 30,
|
3 |
+
"feature_extractor_type": "WhisperFeatureExtractor",
|
4 |
+
"feature_size": 128,
|
5 |
+
"hop_length": 160,
|
6 |
+
"n_fft": 400,
|
7 |
+
"n_samples": 480000,
|
8 |
+
"nb_max_frames": 3000,
|
9 |
+
"padding_side": "right",
|
10 |
+
"padding_value": 0.0,
|
11 |
+
"processor_class": "WhisperProcessor",
|
12 |
+
"return_attention_mask": false,
|
13 |
+
"sampling_rate": 16000
|
14 |
+
}
|
distil-large-v3-init/special_tokens_map.json
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|startoftranscript|>",
|
4 |
+
"<|en|>",
|
5 |
+
"<|zh|>",
|
6 |
+
"<|de|>",
|
7 |
+
"<|es|>",
|
8 |
+
"<|ru|>",
|
9 |
+
"<|ko|>",
|
10 |
+
"<|fr|>",
|
11 |
+
"<|ja|>",
|
12 |
+
"<|pt|>",
|
13 |
+
"<|tr|>",
|
14 |
+
"<|pl|>",
|
15 |
+
"<|ca|>",
|
16 |
+
"<|nl|>",
|
17 |
+
"<|ar|>",
|
18 |
+
"<|sv|>",
|
19 |
+
"<|it|>",
|
20 |
+
"<|id|>",
|
21 |
+
"<|hi|>",
|
22 |
+
"<|fi|>",
|
23 |
+
"<|vi|>",
|
24 |
+
"<|he|>",
|
25 |
+
"<|uk|>",
|
26 |
+
"<|el|>",
|
27 |
+
"<|ms|>",
|
28 |
+
"<|cs|>",
|
29 |
+
"<|ro|>",
|
30 |
+
"<|da|>",
|
31 |
+
"<|hu|>",
|
32 |
+
"<|ta|>",
|
33 |
+
"<|no|>",
|
34 |
+
"<|th|>",
|
35 |
+
"<|ur|>",
|
36 |
+
"<|hr|>",
|
37 |
+
"<|bg|>",
|
38 |
+
"<|lt|>",
|
39 |
+
"<|la|>",
|
40 |
+
"<|mi|>",
|
41 |
+
"<|ml|>",
|
42 |
+
"<|cy|>",
|
43 |
+
"<|sk|>",
|
44 |
+
"<|te|>",
|
45 |
+
"<|fa|>",
|
46 |
+
"<|lv|>",
|
47 |
+
"<|bn|>",
|
48 |
+
"<|sr|>",
|
49 |
+
"<|az|>",
|
50 |
+
"<|sl|>",
|
51 |
+
"<|kn|>",
|
52 |
+
"<|et|>",
|
53 |
+
"<|mk|>",
|
54 |
+
"<|br|>",
|
55 |
+
"<|eu|>",
|
56 |
+
"<|is|>",
|
57 |
+
"<|hy|>",
|
58 |
+
"<|ne|>",
|
59 |
+
"<|mn|>",
|
60 |
+
"<|bs|>",
|
61 |
+
"<|kk|>",
|
62 |
+
"<|sq|>",
|
63 |
+
"<|sw|>",
|
64 |
+
"<|gl|>",
|
65 |
+
"<|mr|>",
|
66 |
+
"<|pa|>",
|
67 |
+
"<|si|>",
|
68 |
+
"<|km|>",
|
69 |
+
"<|sn|>",
|
70 |
+
"<|yo|>",
|
71 |
+
"<|so|>",
|
72 |
+
"<|af|>",
|
73 |
+
"<|oc|>",
|
74 |
+
"<|ka|>",
|
75 |
+
"<|be|>",
|
76 |
+
"<|tg|>",
|
77 |
+
"<|sd|>",
|
78 |
+
"<|gu|>",
|
79 |
+
"<|am|>",
|
80 |
+
"<|yi|>",
|
81 |
+
"<|lo|>",
|
82 |
+
"<|uz|>",
|
83 |
+
"<|fo|>",
|
84 |
+
"<|ht|>",
|
85 |
+
"<|ps|>",
|
86 |
+
"<|tk|>",
|
87 |
+
"<|nn|>",
|
88 |
+
"<|mt|>",
|
89 |
+
"<|sa|>",
|
90 |
+
"<|lb|>",
|
91 |
+
"<|my|>",
|
92 |
+
"<|bo|>",
|
93 |
+
"<|tl|>",
|
94 |
+
"<|mg|>",
|
95 |
+
"<|as|>",
|
96 |
+
"<|tt|>",
|
97 |
+
"<|haw|>",
|
98 |
+
"<|ln|>",
|
99 |
+
"<|ha|>",
|
100 |
+
"<|ba|>",
|
101 |
+
"<|jw|>",
|
102 |
+
"<|su|>",
|
103 |
+
"<|yue|>",
|
104 |
+
"<|translate|>",
|
105 |
+
"<|transcribe|>",
|
106 |
+
"<|startoflm|>",
|
107 |
+
"<|startofprev|>",
|
108 |
+
"<|nospeech|>",
|
109 |
+
"<|notimestamps|>"
|
110 |
+
],
|
111 |
+
"bos_token": {
|
112 |
+
"content": "<|endoftext|>",
|
113 |
+
"lstrip": false,
|
114 |
+
"normalized": false,
|
115 |
+
"rstrip": false,
|
116 |
+
"single_word": false
|
117 |
+
},
|
118 |
+
"eos_token": {
|
119 |
+
"content": "<|endoftext|>",
|
120 |
+
"lstrip": false,
|
121 |
+
"normalized": false,
|
122 |
+
"rstrip": false,
|
123 |
+
"single_word": false
|
124 |
+
},
|
125 |
+
"pad_token": {
|
126 |
+
"content": "<|endoftext|>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false
|
131 |
+
},
|
132 |
+
"unk_token": {
|
133 |
+
"content": "<|endoftext|>",
|
134 |
+
"lstrip": false,
|
135 |
+
"normalized": false,
|
136 |
+
"rstrip": false,
|
137 |
+
"single_word": false
|
138 |
+
}
|
139 |
+
}
|
distil-large-v3-init/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
distil-large-v3-init/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
distil_whisper.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: distil_whisper
|
3 |
+
Version: 0.0.0
|
4 |
+
Summary: Toolkit for distilling OpenAI's Whisper model.
|
5 |
+
Description-Content-Type: text/markdown
|
6 |
+
Requires-Dist: torch>=1.10
|
7 |
+
Requires-Dist: transformers>=4.35.1
|
8 |
+
Requires-Dist: datasets[audio]>=2.14.7
|
9 |
+
Requires-Dist: accelerate>=0.24.1
|
10 |
+
Requires-Dist: jiwer
|
11 |
+
Requires-Dist: evaluate>=0.4.1
|
12 |
+
Requires-Dist: wandb
|
13 |
+
Requires-Dist: tensorboard
|
14 |
+
Requires-Dist: nltk
|
15 |
+
Provides-Extra: dev
|
16 |
+
Requires-Dist: ruff==0.1.5; extra == "dev"
|
17 |
+
|
18 |
+
## Training Distil-Whisper
|
19 |
+
|
20 |
+
This sub-folder contains all the scripts required to train a Distil-Whisper model in your choice of language. They are
|
21 |
+
slightly modified from the original scripts used to distill Whisper for English ASR (as-per the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
|
22 |
+
The main difference is that these scripts are written in [PyTorch](https://pytorch.org), whereas the original scripts
|
23 |
+
are in [JAX](https://jax.readthedocs.io/en/latest/#)/[Flax](https://flax.readthedocs.io/en/latest/). These scripts are
|
24 |
+
also made to be easier to run end-to-end, whereas the original scripts require more steps and are somewhat hard-coded
|
25 |
+
for English ASR. Both sets of scripts achieve equivalent downstream results when the hyper-parameters are set equal.
|
26 |
+
|
27 |
+
If you are interested in reproducing the original Distil-Whisper checkpoints, we refer you to the sub-folder [Flax Training](./flax/README.md).
|
28 |
+
Otherwise, if you wish to distill Whisper on your own language/dataset, we recommend you use these scripts for ease of use
|
29 |
+
and the configurability they provide.
|
30 |
+
|
31 |
+
Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
|
32 |
+
|
33 |
+
1. [Pseudo-labelling](#1-pseudo-labelling)
|
34 |
+
2. [Initialisation](#2-initialisation)
|
35 |
+
3. [Training](#3-training)
|
36 |
+
4. [Evaluation](#4-evaluation)
|
37 |
+
|
38 |
+
This README is partitioned according to the four stages. Each section provides a minimal example for running the
|
39 |
+
scripts used in the project. We will use a running example of distilling the Whisper model for Hindi speech recognition
|
40 |
+
on the Common Voice dataset. Note that this dataset only contains ~20 hours of audio data. Thus, it can be run extremely
|
41 |
+
quickly, but does not provide sufficient data to achieve optimal performance. We recommend training on upwards of 1000
|
42 |
+
hours of data should you want to match the performance of Whisper on high-resource languages.
|
43 |
+
|
44 |
+
## Requirements
|
45 |
+
|
46 |
+
The Distil-Whisper training code is written in [PyTorch](https://pytorch.org) and [Accelerate](https://huggingface.co/docs/accelerate/index).
|
47 |
+
It heavily leverages the Whisper implementation in [🤗 Transformers](https://github.com/huggingface/transformers) for both
|
48 |
+
training and inference.
|
49 |
+
|
50 |
+
The instructions for installing the package are as follows:
|
51 |
+
1. Install PyTorch from the [official instructions](https://pytorch.org/get-started/locally/), ensuring you install the correct version for your hardware and CUDA version.
|
52 |
+
2. Fork the `distil-whisper` repository by clicking on the [fork](https://github.com/huggingface/distil-whisper/fork) button on the reopsitory's page
|
53 |
+
3. Clone the `distil-whisper` repository and add the base repository as a remote. This will allow you to "pull" any upstream changes that are made to the base repository:
|
54 |
+
|
55 |
+
```bash
|
56 |
+
git clone https://github.com/<your GitHub handle>/distil-whisper.git
|
57 |
+
cd distil-whisper
|
58 |
+
git remote add upstream https://github.com/huggingface/distil-whisper.git
|
59 |
+
```
|
60 |
+
4. pip install the required packages from the [setup.py](./setup.py) file:
|
61 |
+
```bash
|
62 |
+
cd training
|
63 |
+
pip install -e .
|
64 |
+
cd ../..
|
65 |
+
```
|
66 |
+
|
67 |
+
5. Configure Accelerate by running the following command. Note that you should set the number of GPUs you wish to use for distillation, and also the data type (dtype) to your preferred dtype for training/inference (e.g. `bfloat16` on A100 GPUs, `float16` on V100 GPUs, etc.):
|
68 |
+
|
69 |
+
```bash
|
70 |
+
accelerate config
|
71 |
+
```
|
72 |
+
|
73 |
+
6. The last thing we need to do is link our Hugging Face account so that we can pull/push model repositories on the Hub. This will allow us to save our final distilled weights on the Hub so that we can share them with the community. Run the command:
|
74 |
+
|
75 |
+
```bash
|
76 |
+
git config --global credential.helper store
|
77 |
+
huggingface-cli login
|
78 |
+
```
|
79 |
+
And then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.
|
80 |
+
|
81 |
+
To confirm that you have a working environment, first accept the terms of use of the Common Voice 16.1 dataset on the Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
|
82 |
+
|
83 |
+
You can run the following code cell to stream one sample of data from the Common Voice dataset, and check that you can
|
84 |
+
perform inference using the "tiny" Whisper model:
|
85 |
+
|
86 |
+
```python
|
87 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
88 |
+
from datasets import load_dataset, Audio
|
89 |
+
|
90 |
+
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", low_cpu_mem_usage=True)
|
91 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
92 |
+
|
93 |
+
model.to("cuda")
|
94 |
+
|
95 |
+
common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
|
96 |
+
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
|
97 |
+
|
98 |
+
inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
|
99 |
+
input_features = inputs.input_features
|
100 |
+
|
101 |
+
generated_ids = model.generate(input_features.to("cuda"), max_new_tokens=128)
|
102 |
+
pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)
|
103 |
+
|
104 |
+
print("Pred text:", pred_text)
|
105 |
+
print("Environment set up successful?", generated_ids.shape[-1] == 20)
|
106 |
+
```
|
107 |
+
|
108 |
+
## 1. Pseudo-Labelling
|
109 |
+
|
110 |
+
The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
|
111 |
+
to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
|
112 |
+
with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
|
113 |
+
datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
|
114 |
+
blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
|
115 |
+
|
116 |
+
> As of the latest Distil-Whisper release, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3), this
|
117 |
+
pseudo-labelling script also performs the added operation of concatenating (or packing) the audio inputs to 30-seconds.
|
118 |
+
Not only does this lead to a WER improvement when using sequential long-form decoding algorithm, but concatenating audios
|
119 |
+
to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised.
|
120 |
+
|
121 |
+
The following script demonstrates how to pseudo-label the Hindi split of the Common Voice 16.1 dataset with greedy sampling:
|
122 |
+
|
123 |
+
```bash
|
124 |
+
#!/usr/bin/env bash
|
125 |
+
|
126 |
+
accelerate launch run_pseudo_labelling.py \
|
127 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
128 |
+
--dataset_name "mozilla-foundation/common_voice_16_1" \
|
129 |
+
--dataset_config_name "hi" \
|
130 |
+
--dataset_split_name "train+validation+test" \
|
131 |
+
--text_column_name "sentence" \
|
132 |
+
--id_column_name "path" \
|
133 |
+
--output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
134 |
+
--wandb_project "distil-whisper-labelling" \
|
135 |
+
--per_device_eval_batch_size 64 \
|
136 |
+
--dtype "bfloat16" \
|
137 |
+
--attn_implementation "sdpa" \
|
138 |
+
--logging_steps 500 \
|
139 |
+
--max_label_length 256 \
|
140 |
+
--concatenate_audio \
|
141 |
+
--preprocessing_batch_size 500 \
|
142 |
+
--preprocessing_num_workers 8 \
|
143 |
+
--dataloader_num_workers 8 \
|
144 |
+
--report_to "wandb" \
|
145 |
+
--language "hi" \
|
146 |
+
--task "transcribe" \
|
147 |
+
--return_timestamps \
|
148 |
+
--streaming False \
|
149 |
+
--generation_num_beams 1 \
|
150 |
+
--push_to_hub
|
151 |
+
```
|
152 |
+
|
153 |
+
On an 80 GB A100 GPU, the following script takes approximately 5 minutes to concatenate and pre-process the 20 hours of
|
154 |
+
audio data, and a further 10 minutes to transcribe the pseudo-labels. The pseudo-labelled dataset corresponding to this
|
155 |
+
script is available on the Hugging Face Hub under [sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled](https://huggingface.co/datasets/sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled).
|
156 |
+
The WER of the pre-trained Whisper large-v3 model is 17.2% on the test split. We will compare the performance of our distilled model against this number.
|
157 |
+
|
158 |
+
There are two noteworthy arguments that configure the dataset concatenation (or packing) process:
|
159 |
+
1. `concatenate_audio`: whether or not to concatenate (or pack) the audios to 30-second chunks. The latest Distil-Whisper model, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3#differences-with-distil-large-v2), highlights the WER improvements obtained using the sequential long-form decoding algorithm when concatenated audios are used. Concatenating audios to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised. Hence, it is highly recommended to set `--concatenate_audio=True`.
|
160 |
+
2. `preprocessing_batch_size`: the batch size to use when concatenating (or packing) the audios. Using a larger batch size results in a greater portion of audio samples being packed to 30-seconds, at the expense of higher memory consumption. If you exceed your system's RAM when performing the concatenation operation, reduce the `preprocessing_batch_size` by a factor of 2 to 250 or even 125.
|
161 |
+
3. `preprocessing_num_workers`: the number of multiprocessing workers to use when concatenating the audios. Using more workers will result in faster pre-processing, at the expense of higher memory consumption. Ensure you do not exceed the maximum number of CPUs on your device.
|
162 |
+
|
163 |
+
In addition, the following arguments configure the inference of the Whisper model:
|
164 |
+
1. `language`: explicitly setting the language token during inference substantially improves the generation performance of the Whisper model, since the model is forced always to predict in the given language. We recommend you set the language to the language you wish to distil the Whisper model on. The only exception is when distilling an English-only model (i.e. where the model id is appended with an `.en`, e.g. `small.en`), the language argument should be set to None, since there is no language token used during training/inference.
|
165 |
+
2. `return_timestamps`: whether or not to predict timestamps in the pseudo-labels. Timestamp prediction is required should you want your distilled model to be able to predict timestamps at inference time (e.g. for the original OpenAI long-form transcription algorithm). However, the pseudo-labels are marginally less accurate than not using timestamps. We recommend pseudo-labelling **with** timestamps to ensure the distilled model is as general as possible.
|
166 |
+
3. `attn_implementation`: which attention implementation to use for inference. Set to `sdpa` for [PyTorch SDPA](https://huggingface.co/docs/transformers/v4.35.2/en/perf_infer_gpu_one#bettertransformer), or `flash_attn_2` if your hardware supports Flash Attention 2 and you have the [package installed](https://github.com/Dao-AILab/flash-attention).
|
167 |
+
4. `streaming`: whether or not to use Datasets' streaming mode. If enabled, the audio data will be streamed from the Hugging Face Hub with no disk space requirements. However, the user is then responsible for adding the pseudo-labels to the dataset script in a follow-up step (see [Using Streaming Mode](#TODO)). If set to `False`, the audio data will be downloaded and pre-processed offline. At the end of pseudo-labelling, the pseudo-labels will be automatically appended to the original dataset, meaning the dataset is ready to be used for the subsequent training step without any additional steps.
|
168 |
+
5. `generation_num_beams`: how many beams to use while decoding. In practice, we found the distilled model to perform comparably when the data was pseudo-labelled with `generation_num_beams=1` (greedy) or `generation_num_beams>1` (beam). This is likely because the WER filter compensates for the lower quality pseudo-labels obtained using greedy search. However, using `generation_num_beams=1` gives substantially faster inference time for the pseudo-labelling step, and so we recommend this configuration.
|
169 |
+
|
170 |
+
Should you have your own audio dataset, you can first [convert it](https://huggingface.co/docs/datasets/audio_dataset) to
|
171 |
+
Hugging Face Datasets format and push it to the Hugging Face Hub. You can then pseudo-label it using the script above,
|
172 |
+
replacing the `--dataset_name` with the name of your dataset on the Hub.
|
173 |
+
|
174 |
+
Otherwise, you may wish to use an open-source dataset already available on the Hugging Face Hub. We provide a summary of
|
175 |
+
the three most popular multilingual datasets in the table below. For more details, refer to the blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#multilingual-speech-recognition).
|
176 |
+
|
177 |
+
| Dataset | Languages | Domain | Speaking Style | License | Text Column | ID Column |
|
178 |
+
|-----------------------------------------------------------------------------------------------|-----------|---------------------------------------|----------------|-----------|---------------------|--------------|
|
179 |
+
| [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) | 6 | Audiobooks | Narrated | CC-BY-4.0 | `"text"` | `"id"` |
|
180 |
+
| [Common Voice 16](https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1) | 120 | Wikipedia text & crowd-sourced speech | Narrated | CC0-1.0 | `"sentence"` | `"path"` |
|
181 |
+
| [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) | 15 | European Parliament recordings | Spontaneous | CC0 | `"normalized_text"` | `"audio_id"` |
|
182 |
+
|
183 |
+
To achieve *robustness* to different distributions of audio data, it is recommended to train on multiple datasets where possible.
|
184 |
+
For example, the above three datasets all have splits for the German language. Thus, if distilling a Whisper model for German,
|
185 |
+
it would be wise to use a combination of the three datasets during training, in order to cover at least three distinct domains
|
186 |
+
(audiobooks, crowd-sourced speech, parliament recordings). You may wish to use a combination of open-source datasets, or
|
187 |
+
a combination of open-source and individually owned datasets to cover multiple distributions and domains. Moreover, if you were to train on low-resource datasets (<500 hours), you could experiment with [language mixing](#3-language-mixing) to improve robustness.
|
188 |
+
|
189 |
+
## 2. Initialisation
|
190 |
+
|
191 |
+
The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
|
192 |
+
from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
|
193 |
+
initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
|
194 |
+
recommendations.
|
195 |
+
|
196 |
+
First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the required files
|
197 |
+
to reproduce the training run, alongside model weights, training logs and a README.md card. You can either create a model
|
198 |
+
repository directly on the Hugging Face Hub using the link: https://huggingface.co/new. Or, via the CLI, as we'll show here.
|
199 |
+
|
200 |
+
Let's pick a name for our distilled model: `distil-whisper-large-v3-hi`. We can run the following command to create a repository under this name:
|
201 |
+
|
202 |
+
```bash
|
203 |
+
huggingface-cli repo create distil-whisper-large-v3-hi
|
204 |
+
```
|
205 |
+
|
206 |
+
We can now see the model on the Hub, e.g. under https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
|
207 |
+
|
208 |
+
Let's clone the repository so that we can place our training script and model weights inside:
|
209 |
+
|
210 |
+
```bash
|
211 |
+
git lfs install
|
212 |
+
git clone https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
|
213 |
+
```
|
214 |
+
|
215 |
+
Be sure to change the repo address to `https://huggingface.co/<your-user-name>/<your-repo-name>`
|
216 |
+
|
217 |
+
We can now copy the relevant training scrips to the repository:
|
218 |
+
```bash
|
219 |
+
cd distil-whisper-large-v3-hi
|
220 |
+
|
221 |
+
cp ../distil-whisper/training/create_student_model.py .
|
222 |
+
cp ../distil-whisper/training/run_distillation.py .
|
223 |
+
```
|
224 |
+
|
225 |
+
The following command demonstrates how to initialise a student model from the Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3)
|
226 |
+
checkpoint, with all 32 encoder layer and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
|
227 |
+
1 and 32 respectively, as the maximally spaced layers:
|
228 |
+
|
229 |
+
```bash
|
230 |
+
#!/usr/bin/env bash
|
231 |
+
|
232 |
+
python create_student_model.py \
|
233 |
+
--teacher_checkpoint "openai/whisper-large-v3" \
|
234 |
+
--encoder_layers 32 \
|
235 |
+
--decoder_layers 2 \
|
236 |
+
--save_dir "./distil-large-v3-init"
|
237 |
+
```
|
238 |
+
|
239 |
+
The initialised model will be saved to the sub-directory `distil-large-v3-init` in our model repository.
|
240 |
+
|
241 |
+
|
242 |
+
**Note:** You can leverage language transfer by setting `--teacher_checkpoint` to "distil-whisper/distil-large-v3", see [language transfer](#22-language-transfer) for more details.
|
243 |
+
|
244 |
+
## 3. Training
|
245 |
+
|
246 |
+
The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
|
247 |
+
datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
|
248 |
+
from the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), which is a weighted sum of the cross-entropy and
|
249 |
+
KL-divergence loss terms.
|
250 |
+
|
251 |
+
The following command takes the Common Voice dataset that was pseudo-labelled in the first stage and trains the
|
252 |
+
2-layer decoder model intialised in the previous step. We pass the local path to the pseudo-labelled Common Voice dataset
|
253 |
+
(`../common_voice_16_1_hi_pseudo_labelled`), which you can change to the path where your local pseudo-labelled dataset is
|
254 |
+
saved.
|
255 |
+
|
256 |
+
In this example, we will combine the train and validation splits to give our training set, and evaluate on the test split
|
257 |
+
only. This is purely to demonstrate how to combine multiple pseudo-labelled datasets for training, rather than recommended
|
258 |
+
advice for defining train/validation splits. We advise that you train on the train splits of your dataset, evaluate and
|
259 |
+
tune hyper-parameters on the validation split, and only test the final checkpoint on the test split. Note how multiple
|
260 |
+
training datasets and splits can be loaded by separating the dataset arguments by `+` symbols. Thus, the script generalises
|
261 |
+
to any number of training datasets.
|
262 |
+
|
263 |
+
```bash
|
264 |
+
#!/usr/bin/env bash
|
265 |
+
|
266 |
+
accelerate launch run_distillation.py \
|
267 |
+
--model_name_or_path "./distil-large-v3-init" \
|
268 |
+
--teacher_model_name_or_path "openai/whisper-large-v3" \
|
269 |
+
--train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
|
270 |
+
--train_split_name "train+validation" \
|
271 |
+
--text_column_name "sentence+sentence" \
|
272 |
+
--train_dataset_samples "7+4" \
|
273 |
+
--eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
|
274 |
+
--eval_split_name "test" \
|
275 |
+
--eval_text_column_name "sentence" \
|
276 |
+
--eval_steps 1000 \
|
277 |
+
--save_steps 1000 \
|
278 |
+
--warmup_steps 50 \
|
279 |
+
--learning_rate 0.0001 \
|
280 |
+
--lr_scheduler_type "constant_with_warmup" \
|
281 |
+
--timestamp_probability 0.2 \
|
282 |
+
--condition_on_prev_probability 0.2 \
|
283 |
+
--language "hi" \
|
284 |
+
--task "transcribe" \
|
285 |
+
--logging_steps 25 \
|
286 |
+
--save_total_limit 1 \
|
287 |
+
--max_steps 5000 \
|
288 |
+
--wer_threshold 20 \
|
289 |
+
--per_device_train_batch_size 32 \
|
290 |
+
--per_device_eval_batch_size 32 \
|
291 |
+
--dataloader_num_workers 8 \
|
292 |
+
--preprocessing_num_workers 8 \
|
293 |
+
--ddp_timeout 7200 \
|
294 |
+
--dtype "bfloat16" \
|
295 |
+
--attn_implementation "sdpa" \
|
296 |
+
--output_dir "./" \
|
297 |
+
--do_train \
|
298 |
+
--do_eval \
|
299 |
+
--gradient_checkpointing \
|
300 |
+
--overwrite_output_dir \
|
301 |
+
--predict_with_generate \
|
302 |
+
--freeze_encoder \
|
303 |
+
--freeze_embed_positions \
|
304 |
+
--streaming False \
|
305 |
+
--push_to_hub
|
306 |
+
|
307 |
+
```
|
308 |
+
|
309 |
+
The above training script will take approximately 3 hours to complete on an 80 GB A100 GPU and yield a final WER of 76%.
|
310 |
+
While the generations are starting to take form, there is still a 59% WER gap to the teacher model. This is hardly
|
311 |
+
surprising give we only have 15 hours of un-filtered data, and closer to just 1.5 hours with data filtering.
|
312 |
+
As mentioned above, using upwards of 1000 hours of data and training for 10k steps will likely yield
|
313 |
+
more competitive performance. For the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), we trained on 21k hours
|
314 |
+
of audio data for 80k steps. We found that upwards of 13k hours of audio data was required to reach convergence on English
|
315 |
+
ASR (see Section 9.2 of the [paper](https://arxiv.org/abs/2311.00430)), so the more data you have, the better!
|
316 |
+
|
317 |
+
Scaling to multiple GPUs using [distributed data parallelism (DDP)](https://pytorch.org/tutorials/beginner/ddp_series_theory.html)
|
318 |
+
is trivial: simply run `accelerate config` and select the multi-GPU option, specifying the IDs of the GPUs you wish to use. The
|
319 |
+
above script can then be run using DDP with no code changes.
|
320 |
+
|
321 |
+
Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
|
322 |
+
saved checkpoint pushed to the Hugging Face Hub can be found here: [sanchit-gandhi/distil-whisper-large-v3-hi](https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi).
|
323 |
+
|
324 |
+
There are a few noteworthy data arguments:
|
325 |
+
1. `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
|
326 |
+
2. `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. In our English distillation experiments, we found a WER threshold of 10% provides the optimal trade-off between ensuring high-quality transcriptions, and not filtering unnecessary amounts of training data. For multilingual distillation, the threshold should be set in accordance with the WER achieved by the pre-trained model on the test set.
|
327 |
+
3. `streaming`: whether or not to use Datasets' streaming mode. Recommended for large datasets, where the audio data can be streamed from the Hugging Face Hub with no disk space requirements.
|
328 |
+
4. `timestamp_probability`: the per-sample probability for retaining timestamp tokens in the labels (should they contain them). Retaining some portion of timestamp tokens in the training data is required to ensure the distilled model can predict timestamps at inference time. In our experiments, we found that training on timestamps with high-probability hurts the distilled model's transcription performance. Thus, we recommend setting this to a value below 0.5. Typically, a value of 0.2 works well, giving good transcription and timestamp performance.
|
329 |
+
5. `condition_on_prev_probability`: the per-sample probability for conditioning on previous labels. Conditioning on previous tokens is required to ensure the distilled model can be used with the "sequential" long-form transcription algorithm at inference time. We did not experiment with this parameter, but found values around 0.2 to provide adequate performance. OpenAI pre-trained Whisper on with a 50% probability for conditioning on previous tokens. Thus, you might wish to try higher values.
|
330 |
+
|
331 |
+
As well as a few noteworthy model arguments that can be configured to give optimal training performance:
|
332 |
+
1. `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
|
333 |
+
2. `freeze_embed_positions`: whether to freeze the student model's decoder positional embeddings. Using the same embed positions as the teacher model, which is designed to handle context lengths up to 448 tokens, helps the student model retain its input id representation up to the full max input length.
|
334 |
+
3. `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
|
335 |
+
4. `freeze_decoder`: whether to freeze the student model's decoder. Note that the input tokens embeddings and language modelling head will remain trainable.
|
336 |
+
|
337 |
+
And finally, a few noteworthy training arguments:
|
338 |
+
1. `max_steps`: defines the total number of optimisation steps (forward + backward pass) during training. To reach convergence, you should use a dataset of at least 1k hours and train for a minimum of 50k steps.
|
339 |
+
2. `lr_scheduler_stype`: defines the learning rate schedule, one of `constant_with_warmup` or `linear`. When experimenting with a training set-up or training for very few steps (< 5k), using `constant_with_warmup` is typically beneficial, since the learning rate remains high over the short training run. When performing long training runs (> 5k), using a `linear` schedule generally results in superior downstream performance of the distilled model.
|
340 |
+
|
341 |
+
TODO:
|
342 |
+
- [ ] Template for model cards
|
343 |
+
|
344 |
+
## 4. Evaluation
|
345 |
+
|
346 |
+
There are four types of evaluation performed in Distil-Whisper:
|
347 |
+
1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
|
348 |
+
2. Sequential long form: evaluation on audio samples longer than 30s in duration using the original "sequential" long-form algorithm. Examples include entire TED talks or earnings calls.
|
349 |
+
3. Chunked long form: evaluation on audio samples longer than 30s in duration using the Transformers "chunked" long-form algorithm.
|
350 |
+
4. Speculative decoding: evaluation on audio samples less than 30s in duration, where a faster, distilled model is used as the assistant to a slower, teacher model.
|
351 |
+
|
352 |
+
All four forms of evaluation are performed using the script [`run_eval.py`](run_eval.py). Unlike the pseudo-labelling
|
353 |
+
and training scripts, the evaluation script assumes that only one GPU accelerator is used. We can copy the corresponding
|
354 |
+
evaluation script to the model repository using the following command:
|
355 |
+
|
356 |
+
```bash
|
357 |
+
cp ../distil-whisper/training/run_eval.py .
|
358 |
+
```
|
359 |
+
|
360 |
+
Models are assessed jointly using:
|
361 |
+
1. The *word-error rate (WER)* metric: measures the number of substitution, deletion and insertion errors relative to the total number of words. A lower WER indicates a more accurate model.
|
362 |
+
2. The *inverse real-time factor (RTFx)* metric: measures the ratio of `audio input time : model compute time`. A higher RTFx indicates a faster model. Note that this metric is WER-dependent, meaning that it makes sense to compare two models' *RTFx* only at fixed *WER* performances. Indeed, deletions could lead to early stopping of token generation, resulting in higher *WER* and lower *RTFx*.
|
363 |
+
3. Token generation speed: This refers to the number of tokens generated per second. As with *RTFx*, this metric is dependent on the *WER* since token generation time is not linear. By default, this metric is calculated by averaging the total number of `generated tokens : generation time` (full forward pass of the model) when evaluating on the given test set. However, using the `--precise_tok_generation` flag will compute this metric separately for a fixed number of tokens.
|
364 |
+
|
365 |
+
In all cases, it is particularly important to evaluate the final model on data that is *out-of-distribution (OOD)* with
|
366 |
+
the training data. Evaluating on OOD data provides insight as to how well the distilled model is likely to generalise to
|
367 |
+
different audio distributions at inference time. In our example, the Common Voice test set is *in-distribution (ID)*
|
368 |
+
with our training data, since it is taken from the same distribution as the Common Voice training set. Whereas the FLEURS
|
369 |
+
test set is OOD, since it is not used as part of the training set. See [Datasets](#1-datasets) section for recommendations.
|
370 |
+
|
371 |
+
### Short Form
|
372 |
+
|
373 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple short-form
|
374 |
+
validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on
|
375 |
+
the Common Voice `test` set (ID) and also the FLEURS `test` set (OOD). Again, it leverages streaming mode to bypass
|
376 |
+
the need to download the data offline:
|
377 |
+
|
378 |
+
```bash
|
379 |
+
#!/usr/bin/env bash
|
380 |
+
|
381 |
+
python run_eval.py \
|
382 |
+
--model_name_or_path "./" \
|
383 |
+
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
|
384 |
+
--dataset_config_name "default+hi_in" \
|
385 |
+
--dataset_split_name "test+test" \
|
386 |
+
--text_column_name "sentence+transcription" \
|
387 |
+
--batch_size 16 \
|
388 |
+
--dtype "bfloat16" \
|
389 |
+
--generation_max_length 256 \
|
390 |
+
--language "hi" \
|
391 |
+
--attn_implementation "sdpa" \
|
392 |
+
--streaming
|
393 |
+
|
394 |
+
```
|
395 |
+
|
396 |
+
The student model achieves an average WER of TODO% with an RTFx of TODO for a batch size of 16. We can easily adapt the above
|
397 |
+
script to evaluate the teacher model, simply by switching the `model_name_or_path` to `openai/whisper-large-v3`, which
|
398 |
+
achieves an average WER of TODO% with an RTFx of TODO. Therefore, for a batch size of 16, the student model is a factor of TODO
|
399 |
+
times faster than the teacher. The WER gap can be closed by training on more data (at least 1k hours) for more training
|
400 |
+
steps (at least 50k).
|
401 |
+
|
402 |
+
### Sequential Long Form
|
403 |
+
|
404 |
+
The original Whisper paper presents a long-form transcription algorithm that sequentially transcribes 30-second segments
|
405 |
+
of audio and shifts the sliding window according to the timestamps predicted by the model. This style of sequential
|
406 |
+
inference is performed directly using the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
|
407 |
+
method in Transformers.
|
408 |
+
|
409 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
|
410 |
+
long-form evaluation sets using the sequential algorithm. Since we don't have a long-form validation set for Hindi to hand,
|
411 |
+
in this example we'll evaluate the official Distil-Whisper model [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3)
|
412 |
+
on the TED-LIUM validation set:
|
413 |
+
|
414 |
+
```bash
|
415 |
+
#!/usr/bin/env bash
|
416 |
+
|
417 |
+
accelerate launch run_eval.py \
|
418 |
+
--model_name_or_path "distil-whisper/distil-large-v3" \
|
419 |
+
--dataset_name "distil-whisper/tedlium-long-form" \
|
420 |
+
--dataset_config_name "default" \
|
421 |
+
--dataset_split_name "validation" \
|
422 |
+
--text_column_name "text" \
|
423 |
+
--batch_size 16 \
|
424 |
+
--dtype "bfloat16" \
|
425 |
+
--generation_max_length 256 \
|
426 |
+
--language "en" \
|
427 |
+
--attn_implementation "sdpa" \
|
428 |
+
--streaming
|
429 |
+
|
430 |
+
```
|
431 |
+
|
432 |
+
### Chunked Long Form
|
433 |
+
|
434 |
+
Chunked long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
|
435 |
+
inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
|
436 |
+
A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
|
437 |
+
|
438 |
+
This style of chunked inference is performed using the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines)
|
439 |
+
class, which provides a wrapper around the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
|
440 |
+
function for long-form inference.
|
441 |
+
|
442 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
|
443 |
+
long-form evaluation sets using the pipeline class. Again, in this example we'll evaluate distil-large-v3 on the
|
444 |
+
TED-LIUM validation set:
|
445 |
+
|
446 |
+
```bash
|
447 |
+
#!/usr/bin/env bash
|
448 |
+
|
449 |
+
python run_eval.py \
|
450 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
451 |
+
--dataset_name "distil-whisper/tedlium-long-form" \
|
452 |
+
--dataset_config_name "default" \
|
453 |
+
--dataset_split_name "validation" \
|
454 |
+
--text_column_name "text" \
|
455 |
+
--use_pipeline \
|
456 |
+
--chunk_length_s 25.0 \
|
457 |
+
--language "en" \
|
458 |
+
--return_timestamps \
|
459 |
+
--dtype "bfloat16" \
|
460 |
+
--streaming
|
461 |
+
|
462 |
+
```
|
463 |
+
|
464 |
+
The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
|
465 |
+
length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
|
466 |
+
it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
|
467 |
+
can be found under [`run_chunk_length_s_sweep.yaml`](flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
|
468 |
+
|
469 |
+
### Speculative Decoding
|
470 |
+
|
471 |
+
Speculative decoding, or assisted generation, relies on the premise that a faster, assistant model can be used to speed-up
|
472 |
+
the generation of a slower, assistant model. Speculative decoding mathematically ensures that exactly the same outputs as
|
473 |
+
Whisper are obtained, while being ~2 times faster. This makes it the perfect drop-in replacement for existing Whisper
|
474 |
+
pipelines, since exactly the same outputs are guaranteed.
|
475 |
+
|
476 |
+
Distil-Whisper checkpoints can be designed to be efficient assistant models to Whisper for speculative decoding. More precisely,
|
477 |
+
by freezing the encoder during training, the distilled model can share the same encoder weights as Whisper during inference, since
|
478 |
+
the encoder weights are un-changed. In doing so, only the distilled 2-layer decoder has to be loaded in addition to the
|
479 |
+
original Whisper model, which is approximately an 8% increase to the total parameter count, with up to 2x faster inference
|
480 |
+
for low batch sizes. For more details on speculative decoding, the reader is advised to refer to the following blog post:
|
481 |
+
[Speculative Decoding for 2x Faster Whisper Inference](https://huggingface.co/blog/whisper-speculative-decoding).
|
482 |
+
|
483 |
+
In the example below, we use our distilled model as an assistant to the large-v3 teacher model during inference:
|
484 |
+
|
485 |
+
```bash
|
486 |
+
#!/usr/bin/env bash
|
487 |
+
|
488 |
+
python run_eval.py \
|
489 |
+
--model_name_or_path "openai/whisper-large-v3" \
|
490 |
+
--assistant_model_name_or_path "./" \
|
491 |
+
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
|
492 |
+
--dataset_config_name "default+hi_in" \
|
493 |
+
--dataset_split_name "test+test" \
|
494 |
+
--text_column_name "sentence+transcription" \
|
495 |
+
--batch_size 16 \
|
496 |
+
--dtype "bfloat16" \
|
497 |
+
--generation_max_length 256 \
|
498 |
+
--language "hi" \
|
499 |
+
--attn_implementation "sdpa" \
|
500 |
+
--streaming
|
501 |
+
|
502 |
+
```
|
503 |
+
|
504 |
+
We see that we achieve a WER of TODO%, the same as what we obtained with the large-v3 model, but with an RTFx of TODO,
|
505 |
+
a factor of TODO faster than using the large-v3 model alone. The RTFx value can be improved by training the student on
|
506 |
+
more data and for more training steps, since this will improve the number of predicted tokens that match the teacher
|
507 |
+
predictions.
|
508 |
+
|
509 |
+
## Recommendations and guidelines
|
510 |
+
|
511 |
+
### 1. Datasets
|
512 |
+
|
513 |
+
As explained, ideally, you should aim for ~1000 hours of audio data for training a distilled model via KD. Moreover, you should evaluate your model on out-of-distribution test sets to assess generalization capacities. With at least 1500 hours of audio data for German, Dutch, French and Spanish, 600 hours for Italian, and 300 hours for Portuguese and Polish (which can be supplemented with your own datasets), a good setup to start with is:
|
514 |
+
- **Training datasets:** [Common Voice 17](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0) and [Multilingual Librispeech](https://huggingface.co/datasets/facebook/multilingual_librispeech). Use the `train` split for training, and the `validation` and `test` splits for in-distribution testing.
|
515 |
+
- **Test datasets:** [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) and [Fleurs](https://huggingface.co/datasets/google/fleurs). Use the `validation` and `test` splits for out-of-distribution testing.
|
516 |
+
|
517 |
+
### 2. Student model's decoder
|
518 |
+
#### 2.1 Number of Decoder Layers
|
519 |
+
|
520 |
+
We recommend using a 2-layers decoder (see language transfer below). However, you can adjust the number of decoder layers when initializing the student model to balance between inference speed and accuracy. Experimentation has revealed that the Pareto optimal points are with 2, 3, and 4-layers decoders. For indicative results, after 10,000 training steps and inference on an 80GB Nvidia H100 with a batch size of 16 and 20 tokens generation, compared to [Whiper *large-v3*](https://huggingface.co/openai/whisper-large-v3) baseline:
|
521 |
+
|
522 |
+
<center>
|
523 |
+
|
524 |
+
| | rel. token gen. speed | ΔWER(%) |
|
525 |
+
|----------|:-------------:|------:|
|
526 |
+
| 2 layers | $3.66$ | $-3.5$ |
|
527 |
+
| 3 layers | $3.35$ | $-2.3$ |
|
528 |
+
| 4 layers | $3.11$ | $-1.8$ |
|
529 |
+
|
530 |
+
</center>
|
531 |
+
|
532 |
+
|
533 |
+
#### 2.2 Language Transfer
|
534 |
+
|
535 |
+
If you opt for a 2-layers decoder, consider leveraging language transfer by initializing the student model from the [distil-large-v3 English distilled model](https://huggingface.co/distil-whisper/distil-large-v3). For French, this method has shown performance improvements of ΔWER=-1.9% (compared to a 2-layers decoder initialized from [Whiper *large-v3*](https://huggingface.co/openai/whisper-large-v3)) after 10,000 training steps.
|
536 |
+
|
537 |
+
```diff
|
538 |
+
- --teacher_checkpoint "openai/whisper-large-v3" \
|
539 |
+
+ --teacher_checkpoint "distil-whisper/distil-large-v3" \
|
540 |
+
```
|
541 |
+
|
542 |
+
### 3. Language mixing
|
543 |
+
|
544 |
+
If you're working with low-resource languages (<500 hours of audio data), consider mixing your training data with a closely related language (for example, mix French and Spanish) to leverage knowledge transfer between languages. Experiments showed that mixing ~400 hours of French (which resulted in a model with poor generalization capacities) with ~500 hours of Spanish improved the model's out-of-distribution performance on French by ΔWER=-7.5%.
|
545 |
+
|
546 |
+
To do this:
|
547 |
+
1. Run [pseudo labeling](#1-pseudo-labelling) for each training dataset, setting the `--language` flag to the language of the respective dataset. In the example of mixing French and Spanish, simply modify the given [pseudo labeling](#1-pseudo-labelling) command with:
|
548 |
+
* pseudo labelling the French dataset
|
549 |
+
```diff
|
550 |
+
- --dataset_config_name "hi" \
|
551 |
+
- --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
552 |
+
- --language "hi" \
|
553 |
+
+ --dataset_config_name "fr" \
|
554 |
+
+ --output_dir "./common_voice_16_1_fr_pseudo_labelled" \
|
555 |
+
+ --language "fr" \
|
556 |
+
```
|
557 |
+
* pseudo labelling the Spanish dataset
|
558 |
+
```diff
|
559 |
+
- --dataset_config_name "hi" \
|
560 |
+
- --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
|
561 |
+
- --language "hi" \
|
562 |
+
+ --dataset_config_name "es" \
|
563 |
+
+ --output_dir "./common_voice_16_1_es_pseudo_labelled" \
|
564 |
+
+ --language "es" \
|
565 |
+
```
|
566 |
+
|
567 |
+
2. Conduct [training](#3-training) on these pseudo-labeled datasets, using the `--language` flag set to your targeted language. Note that this flag is only used for evaluation purposes, so you set it to the targeted language. The language token used for forwarding the teacher and student model decoders is the one used and saved in pseudo labels during pseudo-labeling, ensuring it's the correct one for the considered sample. In the example of mixing French and Spanish, simply modify the given [training](#1-pseudo-labelling) command with:
|
568 |
+
```diff
|
569 |
+
- --train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
|
570 |
+
- --train_split_name "train+validation" \
|
571 |
+
- --eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
|
572 |
+
- --eval_split_name "test" \
|
573 |
+
+ --train_dataset_name "../common_voice_17_0_fr_pseudo_labelled+../common_voice_17_0_es_pseudo_labelled" \
|
574 |
+
+ --train_split_name "train+train" \
|
575 |
+
+ --eval_dataset_name "../common_voice_16_1_fr_pseudo_labelled" \
|
576 |
+
+ --eval_split_name "validation" \
|
577 |
+
```
|
578 |
+
|
579 |
+
## Overview of Training Methods
|
580 |
+
|
581 |
+
### 1. Fine-Tuning
|
582 |
+
|
583 |
+
For fine-tuning, we take the original Whisper checkpoint and train it on one or more datasets using the standard
|
584 |
+
cross-entropy loss. As such, there is no involvement from the teacher checkpoint during training, and so the fine-tuned
|
585 |
+
model is permitted to *overfit* to the distribution of the training data we provide. This makes it appealing for "low-resource"
|
586 |
+
languages where the original Whisper model performs poorly, since we can boost the performance of the model on a single
|
587 |
+
language by *overfitting* to that distribution of data. Note that this means the fine-tuned model is prone to loosing
|
588 |
+
its robustness to different audio distributions, which is the trade-off with improving performance on a specified dataset.
|
589 |
+
|
590 |
+
As a rule of thumb, fine-tuning is appropriate for languages where the original Whisper model performs > 20% WER, and we
|
591 |
+
have a relatively small quantity of training data available (< 1000 hours). With fine-tuning, we require as little as **10 hours**
|
592 |
+
of training data to significantly boost the performance of the Whisper model. For an in-depth guide to fine-tuning Whisper,
|
593 |
+
the reader is advised to refer to the blog post: [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).
|
594 |
+
|
595 |
+
### 2. Shrink and Fine-Tune
|
596 |
+
|
597 |
+
Shrink and fine-tune (SFT) is a knowledge distillation (KD) technique in which we first *shrink* the teacher model to a
|
598 |
+
smaller student model by copying maximally spaced layers, and then *fine-tune* the student model on the cross-entropy loss
|
599 |
+
as described above. Typically, we retain the full encoder from the Whisper model and only shrink the decoder. Retaining
|
600 |
+
the entire encoder helps significantly with maintaining Whisper's robustness to different audio distributions (_c.f._
|
601 |
+
Section 9.3 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
|
602 |
+
|
603 |
+
We can either train the student model on a dataset of (audio, text) pairs as above. Or, we can use the pre-trained
|
604 |
+
Whisper model to generate *pseudo-labels* for our audio data, and train on the (audio, pseudo-label) pairs.
|
605 |
+
|
606 |
+
Pseudo-labels can be used when either:
|
607 |
+
1. The original text transcriptions are normalised (lower-cased or no punctuation): the Whisper generated pseudo-labels contain both punctuation and casing, and so can be used as a substitute for the normalised transcriptions
|
608 |
+
2. The pre-trained Whisper model achieves < 20% WER on the languages: we then know the majority of the pseudo-labels will be accurate enough for us to train on.
|
609 |
+
|
610 |
+
They are not recommended when both of the following are true:
|
611 |
+
1. The original text is punctuated and cased
|
612 |
+
2. The pre-trained Whisper model achieves > 20% WER on the languages: in this case, we want to overfit to the particular distribution of the language, and so train directly on the original text data
|
613 |
+
|
614 |
+
To discard inaccurate pseudo-labels during training, we employ a simple WER heuristic to filter our pseudo-labelled
|
615 |
+
training data. We first normalise the original text and the pseudo-labelled text using the Whisper normaliser. If the
|
616 |
+
WER between the normalised text exceeds a 10% WER threshold, we discard the training sample. Else, we retain it for training.
|
617 |
+
Section 9.1 of the Distil-Whisper [paper](https://arxiv.org/abs/2311.00430) demonstrates the importance of using this
|
618 |
+
threshold for training.
|
619 |
+
|
620 |
+
### 3. KL Divergence
|
621 |
+
|
622 |
+
In the KL Divergence setting, the student model is initialised by shrinking the teacher as before, and then trained to
|
623 |
+
match the predictions of the teacher during training.
|
624 |
+
|
625 |
+
### Summary of Methods
|
626 |
+
|
627 |
+
The following table summarises the two training paradigms: fine-tuning and knowledge distillation (KD). It suggests
|
628 |
+
minimum values for the pre-trained WER / training data to achieve reasonable performance:
|
629 |
+
|
630 |
+
| Method | Pre-Trained WER / % | Training Data / h |
|
631 |
+
|-------------|---------------------|-------------------|
|
632 |
+
| Fine-tuning | > 20 | < 1000 |
|
633 |
+
| KD | < 20 | > 1000 |
|
634 |
+
|
635 |
+
## Acknowledgements
|
636 |
+
|
637 |
+
* OpenAI for the Whisper [model](https://huggingface.co/openai/whisper-large-v3) and [original codebase](https://github.com/openai/whisper)
|
638 |
+
* Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the Whisper model implementation
|
639 |
+
* Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program for Cloud TPU v4s used to train the official Distil-Whisper models
|
640 |
+
* The Hugging Face 🤗 cluster for enabling experimentation with the PyTorch scripts
|
641 |
+
|
642 |
+
## Citation
|
643 |
+
|
644 |
+
If you use this code-base, please consider citing the Distil-Whisper paper:
|
645 |
+
|
646 |
+
```
|
647 |
+
@misc{gandhi2023distilwhisper,
|
648 |
+
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
|
649 |
+
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
|
650 |
+
year={2023},
|
651 |
+
eprint={2311.00430},
|
652 |
+
archivePrefix={arXiv},
|
653 |
+
primaryClass={cs.CL}
|
654 |
+
}
|
655 |
+
```
|
distil_whisper.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
pyproject.toml
|
3 |
+
setup.py
|
4 |
+
distil_whisper.egg-info/PKG-INFO
|
5 |
+
distil_whisper.egg-info/SOURCES.txt
|
6 |
+
distil_whisper.egg-info/dependency_links.txt
|
7 |
+
distil_whisper.egg-info/requires.txt
|
8 |
+
distil_whisper.egg-info/top_level.txt
|
distil_whisper.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
distil_whisper.egg-info/requires.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.10
|
2 |
+
transformers>=4.35.1
|
3 |
+
datasets[audio]>=2.14.7
|
4 |
+
accelerate>=0.24.1
|
5 |
+
jiwer
|
6 |
+
evaluate>=0.4.1
|
7 |
+
wandb
|
8 |
+
tensorboard
|
9 |
+
nltk
|
10 |
+
|
11 |
+
[dev]
|
12 |
+
ruff==0.1.5
|
distil_whisper.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
flax/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
flax/Makefile
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
check_dirs := .
|
2 |
+
|
3 |
+
quality:
|
4 |
+
black --check $(check_dirs)
|
5 |
+
ruff $(check_dirs)
|
6 |
+
|
7 |
+
style:
|
8 |
+
black $(check_dirs)
|
9 |
+
ruff $(check_dirs) --fix
|
flax/README.md
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Reproducing Distil-Whisper
|
2 |
+
|
3 |
+
This sub-folder contains all the training and inference scripts to reproduce the Distil-Whisper project. Distil-Whisper
|
4 |
+
is written in JAX to leverage the fast training and inference speed offered by TPU v4 hardware. However, it also works
|
5 |
+
efficiently on GPU hardware without any additional code changes.
|
6 |
+
|
7 |
+
Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
|
8 |
+
|
9 |
+
1. [Pseudo-labelling](#pseudo-labelling)
|
10 |
+
2. [Initialisation](#initialisation)
|
11 |
+
3. [Training](#training)
|
12 |
+
4. [Evaluation](#evaluation)
|
13 |
+
|
14 |
+
This README is partitioned according to the four stages. Each section provides a minimal example for running the
|
15 |
+
scripts used in the project. The final scripts used to train the model are referenced in-line.
|
16 |
+
|
17 |
+
It is worth noting that the experiments performed in JAX/Flax have been on English ASR only. For multilingual training code,
|
18 |
+
the [PyTorch Training Code](../README.md) can easily be used, facilitating anyone to run Whisper distillation on a language of their choice.
|
19 |
+
|
20 |
+
## Requirements
|
21 |
+
|
22 |
+
Distil-Whisper is written in Python, JAX and Flax, and heavily leverages the Flax Whisper implementation in
|
23 |
+
[🤗 Transformers](https://github.com/huggingface/transformers). The instructions for installing the package are as follows:
|
24 |
+
1. Install JAX from the [official instructions](https://github.com/google/jax#installation), ensuring you install the correct version for your hardware (GPU or TPU).
|
25 |
+
2. Install the `distil_whisper` package by cloning the repository and performing an editable installation:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
git clone https://github.com/huggingface/distil-whisper.git
|
29 |
+
cd distil-whisper/training/flax
|
30 |
+
pip install -e .
|
31 |
+
```
|
32 |
+
|
33 |
+
## Pseudo-Labelling
|
34 |
+
|
35 |
+
Pseudo-labelling is the process of generating target text predictions for the input audio data using the teacher model.
|
36 |
+
The generated text labels then replace the ground truth text labels when performing distillation. The rationale for
|
37 |
+
using pseudo-labels instead of ground truth labels is to circumvent the issue of inconsistent transcription formatting
|
38 |
+
across datasets.
|
39 |
+
|
40 |
+
The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
|
41 |
+
to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
|
42 |
+
with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
|
43 |
+
datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
|
44 |
+
blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
|
45 |
+
|
46 |
+
The following script demonstrates how to pseudo-label the [LibriSpeech 960h](https://huggingface.co/datasets/librispeech_asr)
|
47 |
+
dataset with greedy sampling and streaming mode:
|
48 |
+
|
49 |
+
```bash
|
50 |
+
#!/usr/bin/env bash
|
51 |
+
|
52 |
+
python run_pseudo_labelling.py \
|
53 |
+
--model_name_or_path "openai/whisper-large-v2" \
|
54 |
+
--dataset_name "librispeech_asr" \
|
55 |
+
--dataset_config_name "all" \
|
56 |
+
--data_split_name "train.clean.100+train.clean.360+train.other.500" \
|
57 |
+
--text_column_name "text" \
|
58 |
+
--output_dir "./transcriptions" \
|
59 |
+
--per_device_eval_batch_size 16 \
|
60 |
+
--max_label_length 256 \
|
61 |
+
--dtype "bfloat16" \
|
62 |
+
--report_to "wandb" \
|
63 |
+
--dataloader_num_workers 16 \
|
64 |
+
--streaming \
|
65 |
+
--push_to_hub \
|
66 |
+
--generation_num_beams 1 # for greedy, set >1 for beam
|
67 |
+
|
68 |
+
```
|
69 |
+
|
70 |
+
The script will save the generated pseudo-labels alongside the file ids to the output directory `output_dir`. Adding the
|
71 |
+
`--push_to_hub` argument uploads the generated pseudo-labels to the Hugging Face Hub on save.
|
72 |
+
|
73 |
+
The directory [`pseudo_labelling_scripts`](pseudo_labelling_scripts) contains a collection of bash scripts for
|
74 |
+
pseudo-labelling all 10 audio datasets used in the project. The datasets with the Whisper generated transcriptions
|
75 |
+
can be found on the Hugging Face Hub under the [Distil Whisper organisation](https://huggingface.co/datasets?sort=trending&search=distil-whisper%2F).
|
76 |
+
They can be re-used should you wish to bypass the data labelling stage of the reproduction.
|
77 |
+
|
78 |
+
<!--- TODO(SG): Combine PS with source audio to create dataset --->
|
79 |
+
|
80 |
+
## Initialisation
|
81 |
+
|
82 |
+
The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
|
83 |
+
from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
|
84 |
+
initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
|
85 |
+
recommendations.
|
86 |
+
|
87 |
+
The following command demonstrates how to initialise a student model from the [large-v2](https://huggingface.co/openai/whisper-large-v2)
|
88 |
+
checkpoint, with all 32 encoder layer and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
|
89 |
+
1 and 32 respectively, as the maximally spaced layers.
|
90 |
+
|
91 |
+
```bash
|
92 |
+
#!/usr/bin/env bash
|
93 |
+
|
94 |
+
python create_student_model.py \
|
95 |
+
--teacher_checkpoint "openai/whisper-large-v2" \
|
96 |
+
--encoder_layers 32 \
|
97 |
+
--decoder_layers 2 \
|
98 |
+
--save_dir "./large-32-2" \
|
99 |
+
--push_to_hub
|
100 |
+
```
|
101 |
+
|
102 |
+
|
103 |
+
## Training
|
104 |
+
|
105 |
+
The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
|
106 |
+
datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
|
107 |
+
from [DistilBart](https://arxiv.org/abs/2010.13002), which is a combination of a cross-entropy, KL-divergence and
|
108 |
+
mean-square error (MSE) loss:
|
109 |
+
|
110 |
+
https://github.com/huggingface/distil-whisper/blob/4dd831543e6c40b1159f1ec951db7f4fe0e86850/run_distillation.py#L1725
|
111 |
+
|
112 |
+
The weight assigned to the MSE loss is configurable. The others are fixed to the values from the DistilBART paper.
|
113 |
+
|
114 |
+
The following command takes the LibriSpeech 960h dataset that was pseudo-labelled in the first stage and trains the
|
115 |
+
2-layer decoder model intialised in the previous step. Note that multiple training datasets and splits can be loaded
|
116 |
+
by separating the dataset arguments by `+` symbols. Thus, the script generalises to any number of training datasets.
|
117 |
+
|
118 |
+
```bash
|
119 |
+
#!/usr/bin/env bash
|
120 |
+
|
121 |
+
python3 run_distillation.py \
|
122 |
+
--model_name_or_path "./large-32-2" \
|
123 |
+
--teacher_model_name_or_path "openai/whisper-large-v2" \
|
124 |
+
--train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr" \
|
125 |
+
--train_dataset_config_name "all+all+all" \
|
126 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500" \
|
127 |
+
--train_dataset_samples "100+360+500" \
|
128 |
+
--eval_dataset_name "librispeech_asr" \
|
129 |
+
--eval_dataset_config_name "all" \
|
130 |
+
--eval_split_name "validation.clean" \
|
131 |
+
--eval_steps 5000 \
|
132 |
+
--save_steps 5000 \
|
133 |
+
--warmup_steps 500 \
|
134 |
+
--learning_rate 0.0001 \
|
135 |
+
--lr_scheduler_type "constant_with_warmup" \
|
136 |
+
--logging_steps 25 \
|
137 |
+
--save_total_limit 1 \
|
138 |
+
--max_steps 20000 \
|
139 |
+
--wer_threshold 10 \
|
140 |
+
--per_device_train_batch_size 64 \
|
141 |
+
--per_device_eval_batch_size 64 \
|
142 |
+
--dataloader_num_workers 16 \
|
143 |
+
--dtype "bfloat16" \
|
144 |
+
--output_dir "./" \
|
145 |
+
--do_train \
|
146 |
+
--do_eval \
|
147 |
+
--use_scan \
|
148 |
+
--gradient_checkpointing \
|
149 |
+
--overwrite_output_dir \
|
150 |
+
--predict_with_generate \
|
151 |
+
--freeze_encoder \
|
152 |
+
--streaming \
|
153 |
+
--use_auth_token \
|
154 |
+
--push_to_hub
|
155 |
+
|
156 |
+
```
|
157 |
+
|
158 |
+
The above training script will take approximately 20 hours to complete on a TPU v4-8 and yield a final WER of 2.3%.
|
159 |
+
|
160 |
+
Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
|
161 |
+
saved checkpoint pushed to the Hugging Face Hub can be found here: [large-32-2](https://huggingface.co/distil-whisper/large-32-2).
|
162 |
+
|
163 |
+
There are a few noteworthy arguments that can be configured to give optimal training performance:
|
164 |
+
* `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
|
165 |
+
* `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong.
|
166 |
+
* `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
|
167 |
+
* `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
|
168 |
+
|
169 |
+
The Distil Whisper project extends the above script to train on a combined dataset formed from 12 open-source ASR datasets,
|
170 |
+
totalling 22k hours and over 50k speakers. Template scripts to run training on this composite dataset can be found
|
171 |
+
in the directory [`distillation_scripts`](distillation_scripts).
|
172 |
+
|
173 |
+
## Evaluation
|
174 |
+
|
175 |
+
There are two types of evaluation performed in Distil-Whisper:
|
176 |
+
1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
|
177 |
+
2. Long form: evaluation on audio samples longer than 30s in duration. Examples include entire TED talks or earnings calls.
|
178 |
+
|
179 |
+
Both forms of evaluation are performed using the *word-error rate (WER)* metric.
|
180 |
+
|
181 |
+
### Short Form
|
182 |
+
|
183 |
+
The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple validation sets.
|
184 |
+
The following example demonstrates how to evaluate the student model trained in the previous step on the LibriSpeech
|
185 |
+
`validation.clean` and `validation.other` dev sets. Again, it leverages streaming mode to bypass the need to download
|
186 |
+
the data offline:
|
187 |
+
|
188 |
+
```bash
|
189 |
+
#!/usr/bin/env bash
|
190 |
+
|
191 |
+
python run_eval.py \
|
192 |
+
--model_name_or_path "./large-32-2" \
|
193 |
+
--dataset_name "librispeech_asr+librispeech_asr" \
|
194 |
+
--dataset_config_name "all+all" \
|
195 |
+
--dataset_split_name "validation.clean+validation.other" \
|
196 |
+
--output_dir "./large-32-2" \
|
197 |
+
--per_device_eval_batch_size 64 \
|
198 |
+
--dtype "bfloat16" \
|
199 |
+
--dataloader_num_workers 16 \
|
200 |
+
--report_to "wandb" \
|
201 |
+
--streaming \
|
202 |
+
--predict_with_generate
|
203 |
+
|
204 |
+
```
|
205 |
+
|
206 |
+
### Long Form
|
207 |
+
|
208 |
+
Long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
|
209 |
+
inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
|
210 |
+
A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
|
211 |
+
|
212 |
+
This style of chunked inference is performed using the [`FlaxWhisperPipeline`](https://github.com/huggingface/distil-whisper/blob/6426022e3b3a0a498b4150a636b54e2e3898bf1a/distil_whisper/pipeline.py#L61)
|
213 |
+
class, which is heavily inspired from [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax/tree/main#pipeline-usage).
|
214 |
+
|
215 |
+
The script [`run_long_form_transcription.py`](run_long_form_transcription.py) can be used to evaluate the trained
|
216 |
+
student model on an arbitrary number of long-form evaluation sets. The following script demonstrates how to evaluate
|
217 |
+
the example student model on two such test sets, [Earnings 21](https://huggingface.co/datasets/distil-whisper/earnings21)
|
218 |
+
and [Earnings 22](https://huggingface.co/datasets/distil-whisper/earnings22):
|
219 |
+
|
220 |
+
```bash
|
221 |
+
#!/usr/bin/env bash
|
222 |
+
|
223 |
+
python run_long_form_transcription.py \
|
224 |
+
--model_name_or_path "./large-32-2" \
|
225 |
+
--dataset_name "distil-whisper/earnings21+distil-whisper/earnings22" \
|
226 |
+
--dataset_config_name "default+default" \
|
227 |
+
--dataset_split_name "test+test+test+test" \
|
228 |
+
--text_column_name "transcription+transcription" \
|
229 |
+
--output_dir "./large-32-2" \
|
230 |
+
--per_device_eval_batch_size 64 \
|
231 |
+
--chunk_length_s 15 \
|
232 |
+
--dtype "bfloat16" \
|
233 |
+
--report_to "wandb" \
|
234 |
+
--streaming
|
235 |
+
|
236 |
+
```
|
237 |
+
|
238 |
+
The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
|
239 |
+
length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
|
240 |
+
it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
|
241 |
+
can be found under [`run_chunk_length_s_sweep.yaml`](long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
|
242 |
+
|
243 |
+
### 1. Pseudo Labelling
|
244 |
+
|
245 |
+
#### Greedy vs Beam
|
246 |
+
|
247 |
+
We found there to be little-to-no difference in the downstream performance of the distilled model after pseudo labelling
|
248 |
+
using either greedy or beam-search. We attribute this to the minimal difference in performance of the pre-trained Whisper
|
249 |
+
model under greedy and beam-search decoding, giving pseudo-labelled transcriptions of similar quality. We encourage
|
250 |
+
users to generate pseudo-labels using greedy decoding given it runs significantly faster. Beam search is only advised if
|
251 |
+
the pre-trained model is hallucinating significantly on the audio inputs, in which case it helps reduce the frequency and
|
252 |
+
severity of hallucinations. If using beam search, the number of beams can be kept low: even 2 beams helps reduce the
|
253 |
+
amount of hallucinations significantly.
|
254 |
+
|
255 |
+
#### Timestamps
|
256 |
+
|
257 |
+
Whisper is trained on a timestamp prediction task as part of the pre-training set-up. Here, a fixed proportion of the
|
258 |
+
pre-training data includes sequence-level *timestamps* as part of the transcription labels:
|
259 |
+
|
260 |
+
```bash
|
261 |
+
<|0.00|> Hey, this is a test transcription. <|3.42|>
|
262 |
+
```
|
263 |
+
|
264 |
+
Timestamp prediction is useful for enriching the transcriptions with timing information for downstream tasks, such as
|
265 |
+
aligning the Whisper transcription with the output of a speaker diarization system, and also reduces the frequency of
|
266 |
+
hallucinations.
|
267 |
+
|
268 |
+
The pseudo-labelling scrip [`run_pseudo_labelling.py`](run_pseudo_labelling.py) can be extended to predict timestamp
|
269 |
+
information in the audio data by appending the `--return_timestamps` flag to the launch command. The timestamped labelled
|
270 |
+
data can be passed to the training script in exactly the same way as the non-timestamped version, and the pre-processing
|
271 |
+
function will take care of encoding the timestamps and appending the required task tokens.
|
272 |
+
|
273 |
+
#### Previous Context
|
274 |
+
|
275 |
+
Whisper is also pre-trained on a prompting task, where the transcription for the preceding utterance is fed as context
|
276 |
+
to the current one:
|
277 |
+
|
278 |
+
```bash
|
279 |
+
<|startofprev|> This is the previous context from the preceding utterance.<|startoftranscript|> And this is the current utterance.<|endoftranscript|>
|
280 |
+
```
|
281 |
+
|
282 |
+
Annotating the transcriptions with previous context labels is only possible for datasets where we have consecutive files
|
283 |
+
and unique speaker ids, since we need to ensure segment `i` directly follows on from segment `i-1` if we use it as the
|
284 |
+
prompt.
|
285 |
+
|
286 |
+
As per the Whisper paper, we mask out the loss over the previous context tokens. At inference time, we can replace the
|
287 |
+
previous context with a “prompt” to encourage the model to generate text in the style of the prompt (i.e. for specific
|
288 |
+
named entities, or styles of transcription)
|
289 |
+
|
290 |
+
## Acknowledgements
|
291 |
+
|
292 |
+
* 🤗 Hugging Face Transformers for the base Whisper implementation
|
293 |
+
* Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for their generous provision of Cloud TPUs
|
flax/conversion_scripts/run_convert_distilled_train_state_to_hf.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python convert_train_state_to_hf.py \
|
4 |
+
--model_name_or_path "distil-whisper/large-32-2" \
|
5 |
+
--output_dir "./" \
|
6 |
+
--resume_from_checkpoint "checkpoint-15000" \
|
7 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
8 |
+
--use_scan
|
flax/convert_train_state_to_hf.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Convert a Flax training state to HF Transformers Whisper weights.
|
18 |
+
"""
|
19 |
+
|
20 |
+
import logging
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
from dataclasses import field
|
24 |
+
from pathlib import Path
|
25 |
+
from typing import Callable, Optional
|
26 |
+
|
27 |
+
import flax
|
28 |
+
import jax
|
29 |
+
import jax.numpy as jnp
|
30 |
+
import optax
|
31 |
+
from flax import jax_utils, traverse_util
|
32 |
+
from flax.serialization import from_bytes
|
33 |
+
from flax.training import train_state
|
34 |
+
from flax.training.common_utils import shard_prng_key
|
35 |
+
from huggingface_hub import Repository, create_repo
|
36 |
+
from optax._src import linear_algebra
|
37 |
+
from transformers import (
|
38 |
+
AutoConfig,
|
39 |
+
HfArgumentParser,
|
40 |
+
Seq2SeqTrainingArguments,
|
41 |
+
)
|
42 |
+
from transformers.file_utils import get_full_repo_name
|
43 |
+
from transformers.utils import check_min_version
|
44 |
+
from transformers.utils.versions import require_version
|
45 |
+
|
46 |
+
from distil_whisper import FlaxWhisperForConditionalGeneration
|
47 |
+
|
48 |
+
|
49 |
+
# initialise JAX for multi-host set-up on TPU
|
50 |
+
jax.distributed.initialize()
|
51 |
+
|
52 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
53 |
+
check_min_version("4.27.0.dev0")
|
54 |
+
|
55 |
+
require_version(
|
56 |
+
"datasets>=1.18.0",
|
57 |
+
"To fix: pip install -r examples/flax/speech-recogintion/requirements.txt",
|
58 |
+
)
|
59 |
+
|
60 |
+
logger = logging.getLogger(__name__)
|
61 |
+
|
62 |
+
|
63 |
+
@flax.struct.dataclass
|
64 |
+
class ModelArguments:
|
65 |
+
"""
|
66 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
67 |
+
"""
|
68 |
+
|
69 |
+
model_name_or_path: str = field(
|
70 |
+
metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
|
71 |
+
)
|
72 |
+
config_name: Optional[str] = field(
|
73 |
+
default=None,
|
74 |
+
metadata={"help": "Pretrained config name or path if not the same as model_name"},
|
75 |
+
)
|
76 |
+
cache_dir: Optional[str] = field(
|
77 |
+
default=None,
|
78 |
+
metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
|
79 |
+
)
|
80 |
+
use_fast_tokenizer: bool = field(
|
81 |
+
default=True,
|
82 |
+
metadata={"help": ("Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.")},
|
83 |
+
)
|
84 |
+
model_revision: str = field(
|
85 |
+
default="main",
|
86 |
+
metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
|
87 |
+
)
|
88 |
+
use_auth_token: bool = field(
|
89 |
+
default=False,
|
90 |
+
metadata={
|
91 |
+
"help": (
|
92 |
+
"Will use the token generated when running `transformers-cli login`"
|
93 |
+
" (necessary to use this script with private models)."
|
94 |
+
)
|
95 |
+
},
|
96 |
+
)
|
97 |
+
dtype: Optional[str] = field(
|
98 |
+
default="float32",
|
99 |
+
metadata={
|
100 |
+
"help": (
|
101 |
+
"Floating-point format in which the model weights should be initialized"
|
102 |
+
" and trained. Choose one of `[float32, float16, bfloat16]`."
|
103 |
+
)
|
104 |
+
},
|
105 |
+
)
|
106 |
+
load_with_scan_weights: bool = field(
|
107 |
+
default=False,
|
108 |
+
metadata={
|
109 |
+
"help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
|
110 |
+
"weights, defaults to False for non-scan (unrolled) weights."
|
111 |
+
},
|
112 |
+
)
|
113 |
+
use_scan: bool = field(
|
114 |
+
default=True,
|
115 |
+
metadata={"help": ("Whether or not to use `scan_with_axes` over the encoder and decoder blocks.")},
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
def create_learning_rate_fn(
|
120 |
+
num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
|
121 |
+
) -> Callable[[int], jnp.array]:
|
122 |
+
"""Returns a linear warmup, linear_decay learning rate function."""
|
123 |
+
lr_scheduler_types = ("linear", "constant_with_warmup")
|
124 |
+
|
125 |
+
if lr_scheduler_type not in lr_scheduler_types:
|
126 |
+
raise ValueError(
|
127 |
+
f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
|
128 |
+
)
|
129 |
+
|
130 |
+
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
131 |
+
decay_fn = optax.linear_schedule(
|
132 |
+
init_value=learning_rate,
|
133 |
+
end_value=0 if lr_scheduler_type == "linear" else learning_rate,
|
134 |
+
transition_steps=num_train_steps - num_warmup_steps,
|
135 |
+
)
|
136 |
+
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
137 |
+
return schedule_fn
|
138 |
+
|
139 |
+
|
140 |
+
class TrainState(train_state.TrainState):
|
141 |
+
dropout_rng: jnp.ndarray
|
142 |
+
max_grad_norm: float
|
143 |
+
|
144 |
+
def apply_gradients(self, *, grads, **kwargs):
|
145 |
+
"""Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
|
146 |
+
gradients by the maximum grad norm.
|
147 |
+
|
148 |
+
Note that internally this function calls `.tx.update()` followed by a call
|
149 |
+
to `optax.apply_updates()` to update `params` and `opt_state`.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
grads: Gradients that have the same pytree structure as `.params`.
|
153 |
+
**kwargs: Additional dataclass attributes that should be `.replace()`-ed.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
An updated instance of `self` with `step` incremented by one, `params`
|
157 |
+
and `opt_state` updated by applying `grads`, and additional attributes
|
158 |
+
replaced as specified by `kwargs`.
|
159 |
+
"""
|
160 |
+
# clip gradients by global l2 norm
|
161 |
+
g_norm = linear_algebra.global_norm(grads)
|
162 |
+
g_norm = jnp.maximum(self.max_grad_norm, g_norm)
|
163 |
+
grads = jax.tree_map(lambda t: (t / g_norm) * self.max_grad_norm, grads)
|
164 |
+
|
165 |
+
updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
|
166 |
+
new_params = optax.apply_updates(self.params, updates)
|
167 |
+
|
168 |
+
return self.replace(
|
169 |
+
step=self.step + 1,
|
170 |
+
params=new_params,
|
171 |
+
opt_state=new_opt_state,
|
172 |
+
**kwargs,
|
173 |
+
)
|
174 |
+
|
175 |
+
def replicate(self):
|
176 |
+
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
177 |
+
|
178 |
+
def unreplicate(self):
|
179 |
+
return jax_utils.unreplicate(self)
|
180 |
+
|
181 |
+
|
182 |
+
def main():
|
183 |
+
# 1. Parse input arguments
|
184 |
+
# See all possible arguments in src/transformers/training_args.py
|
185 |
+
# or by passing the --help flag to this script.
|
186 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
187 |
+
parser = HfArgumentParser(
|
188 |
+
(
|
189 |
+
ModelArguments,
|
190 |
+
Seq2SeqTrainingArguments,
|
191 |
+
)
|
192 |
+
)
|
193 |
+
|
194 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
195 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
196 |
+
# let's parse it to get our arguments.
|
197 |
+
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
198 |
+
else:
|
199 |
+
model_args, training_args = parser.parse_args_into_dataclasses()
|
200 |
+
|
201 |
+
# Handle the repository creation
|
202 |
+
if training_args.push_to_hub:
|
203 |
+
if training_args.hub_model_id is None:
|
204 |
+
repo_name = get_full_repo_name(
|
205 |
+
Path(training_args.output_dir).absolute().name,
|
206 |
+
token=training_args.hub_token,
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
repo_name = training_args.hub_model_id
|
210 |
+
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
|
211 |
+
repo = Repository(
|
212 |
+
training_args.output_dir,
|
213 |
+
clone_from=repo_name,
|
214 |
+
token=training_args.hub_token,
|
215 |
+
)
|
216 |
+
|
217 |
+
# 5. Load pretrained config, model and processor
|
218 |
+
config = AutoConfig.from_pretrained(
|
219 |
+
(model_args.config_name if model_args.config_name else model_args.model_name_or_path),
|
220 |
+
cache_dir=model_args.cache_dir,
|
221 |
+
revision=model_args.model_revision,
|
222 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
223 |
+
)
|
224 |
+
student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
225 |
+
model_args.model_name_or_path,
|
226 |
+
config=config,
|
227 |
+
dtype=getattr(jnp, model_args.dtype),
|
228 |
+
cache_dir=model_args.cache_dir,
|
229 |
+
revision=model_args.model_revision,
|
230 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
231 |
+
_do_init=False,
|
232 |
+
use_scan=model_args.load_with_scan_weights,
|
233 |
+
)
|
234 |
+
|
235 |
+
# enable scan / gradient checkpointing if necessary in the student model
|
236 |
+
if model_args.use_scan:
|
237 |
+
student_model.enable_scan() # to enable scan in the nn.Module
|
238 |
+
student_params = student_model.convert_unroll_to_scan(student_params) # to convert the unrolled params to scan
|
239 |
+
|
240 |
+
# Initialize our student state
|
241 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
242 |
+
rng, dropout_rng = jax.random.split(rng)
|
243 |
+
|
244 |
+
total_train_steps = int(training_args.max_steps)
|
245 |
+
|
246 |
+
# Create learning rate schedule
|
247 |
+
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
248 |
+
total_train_steps,
|
249 |
+
training_args.lr_scheduler_type,
|
250 |
+
training_args.warmup_steps,
|
251 |
+
training_args.learning_rate,
|
252 |
+
)
|
253 |
+
|
254 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
255 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
256 |
+
# mask boolean with the same structure as the parameters.
|
257 |
+
# The mask is True for parameters that should be decayed.
|
258 |
+
def decay_mask_fn(params):
|
259 |
+
flat_params = traverse_util.flatten_dict(params)
|
260 |
+
# find out all LayerNorm parameters
|
261 |
+
layer_norm_candidates = [
|
262 |
+
"layer_norm",
|
263 |
+
"self_attn_layer_norm",
|
264 |
+
"final_layer_norm",
|
265 |
+
"encoder_attn_layer_norm",
|
266 |
+
]
|
267 |
+
layer_norm_named_params = {
|
268 |
+
layer[-2:]
|
269 |
+
for layer_norm_name in layer_norm_candidates
|
270 |
+
for layer in flat_params.keys()
|
271 |
+
if layer_norm_name in "".join(layer).lower()
|
272 |
+
}
|
273 |
+
flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
|
274 |
+
return traverse_util.unflatten_dict(flat_mask)
|
275 |
+
|
276 |
+
# create adam optimizer
|
277 |
+
adamw = optax.adamw(
|
278 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
279 |
+
b1=training_args.adam_beta1,
|
280 |
+
b2=training_args.adam_beta2,
|
281 |
+
eps=training_args.adam_epsilon,
|
282 |
+
weight_decay=training_args.weight_decay,
|
283 |
+
mask=decay_mask_fn,
|
284 |
+
)
|
285 |
+
|
286 |
+
# Setup train state
|
287 |
+
student_state = TrainState.create(
|
288 |
+
apply_fn=student_model.__call__,
|
289 |
+
params=student_params,
|
290 |
+
tx=adamw,
|
291 |
+
dropout_rng=dropout_rng,
|
292 |
+
max_grad_norm=training_args.max_grad_norm,
|
293 |
+
)
|
294 |
+
|
295 |
+
if training_args.resume_from_checkpoint is not None:
|
296 |
+
if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
|
297 |
+
logger.info(
|
298 |
+
f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
|
299 |
+
"this behavior, omit the resume_from_checkpoint argument."
|
300 |
+
)
|
301 |
+
with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
|
302 |
+
student_state = from_bytes(student_state, f.read())
|
303 |
+
else:
|
304 |
+
logger.warning(
|
305 |
+
f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
|
306 |
+
f"you pass the path to a folder with a valid checkpoint for your model."
|
307 |
+
)
|
308 |
+
|
309 |
+
cur_step = int(jax.device_get(student_state.step))
|
310 |
+
|
311 |
+
# save weights in HF Transformers format
|
312 |
+
if jax.process_index() == 0:
|
313 |
+
student_model.disable_scan()
|
314 |
+
student_state_params = student_model.convert_scan_to_unroll(student_state.params)
|
315 |
+
student_params = jax.device_get(student_state_params)
|
316 |
+
student_model.save_pretrained(
|
317 |
+
os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
|
318 |
+
)
|
319 |
+
if training_args.push_to_hub:
|
320 |
+
repo.push_to_hub(
|
321 |
+
commit_message=f"Saving weights of step {cur_step}",
|
322 |
+
blocking=False,
|
323 |
+
)
|
324 |
+
|
325 |
+
|
326 |
+
if __name__ == "__main__":
|
327 |
+
main()
|
flax/create_student_model.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Initialise a student Whisper model from a pre-trained teacher model for
|
18 |
+
teacher-student distillation.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import copy
|
23 |
+
import logging
|
24 |
+
|
25 |
+
import jax
|
26 |
+
import numpy as np
|
27 |
+
from flax.core import freeze, unfreeze
|
28 |
+
from transformers import GenerationConfig, WhisperFeatureExtractor, WhisperProcessor
|
29 |
+
|
30 |
+
from distil_whisper import FlaxWhisperForConditionalGeneration
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
def parse_args():
|
37 |
+
parser = argparse.ArgumentParser(
|
38 |
+
description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--teacher_checkpoint",
|
42 |
+
type=str,
|
43 |
+
required=True,
|
44 |
+
help="The HF Hub ID of the teacher checkpoint.",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--subfolder",
|
48 |
+
type=str,
|
49 |
+
default="",
|
50 |
+
help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
|
51 |
+
"can specify the folder name here.",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--encoder_layers",
|
55 |
+
type=int,
|
56 |
+
default=None,
|
57 |
+
help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--decoder_layers",
|
61 |
+
type=int,
|
62 |
+
default=2,
|
63 |
+
help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--max_source_positions",
|
67 |
+
type=int,
|
68 |
+
default=None,
|
69 |
+
help="The maximum sequence length of log-mel filter-bank features that this model might ever be used with. Can "
|
70 |
+
"be used to create a student model with a shorter context length than the teacher model. Defaults to the number "
|
71 |
+
"of source positions in the teacher model (1500).",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--save_dir",
|
75 |
+
type=str,
|
76 |
+
required=True,
|
77 |
+
help="Where to save the student weights and processor.",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--push_to_hub",
|
81 |
+
type=bool,
|
82 |
+
required=False,
|
83 |
+
default=False,
|
84 |
+
help="Whether to push the student weights and processor to the Hub.",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--cache_dir",
|
88 |
+
type=str,
|
89 |
+
default=None,
|
90 |
+
help="Where to store the pretrained models downloaded from huggingface.co",
|
91 |
+
)
|
92 |
+
|
93 |
+
args = parser.parse_args()
|
94 |
+
return args
|
95 |
+
|
96 |
+
|
97 |
+
def init_student_model_from_teacher(
|
98 |
+
teacher_checkpoint,
|
99 |
+
encoder_layers=None,
|
100 |
+
decoder_layers=2,
|
101 |
+
max_source_positions=None,
|
102 |
+
save_dir=None,
|
103 |
+
push_to_hub=None,
|
104 |
+
cache_dir=None,
|
105 |
+
subfolder="",
|
106 |
+
):
|
107 |
+
teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
108 |
+
teacher_checkpoint,
|
109 |
+
_do_init=False,
|
110 |
+
cache_dir=cache_dir,
|
111 |
+
subfolder=subfolder,
|
112 |
+
)
|
113 |
+
processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
|
114 |
+
generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
|
115 |
+
|
116 |
+
teacher_config = teacher_model.config
|
117 |
+
teacher_encoder_layers = teacher_config.encoder_layers
|
118 |
+
teacher_decoder_layers = teacher_config.decoder_layers
|
119 |
+
|
120 |
+
student_config = copy.deepcopy(teacher_config)
|
121 |
+
student_config.update(
|
122 |
+
{
|
123 |
+
"encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
|
124 |
+
"decoder_layers": decoder_layers,
|
125 |
+
"max_source_positions": (
|
126 |
+
max_source_positions if max_source_positions is not None else student_config.max_source_positions
|
127 |
+
),
|
128 |
+
}
|
129 |
+
)
|
130 |
+
|
131 |
+
encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
|
132 |
+
encoder_mapping[-1] = teacher_encoder_layers - 1
|
133 |
+
|
134 |
+
encoder_map = {}
|
135 |
+
for student_layer, teacher_layer in enumerate(encoder_mapping):
|
136 |
+
encoder_map[str(teacher_layer)] = str(student_layer)
|
137 |
+
|
138 |
+
decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
|
139 |
+
decoder_mapping[-1] = teacher_decoder_layers - 1
|
140 |
+
|
141 |
+
decoder_map = {}
|
142 |
+
for student_layer, teacher_layer in enumerate(decoder_mapping):
|
143 |
+
decoder_map[str(teacher_layer)] = str(student_layer)
|
144 |
+
|
145 |
+
# init the student params from the teacher model
|
146 |
+
student_params = unfreeze(teacher_params)
|
147 |
+
student_params["model"]["decoder"]["layers"] = {}
|
148 |
+
|
149 |
+
for layer in teacher_params["model"]["decoder"]["layers"]:
|
150 |
+
if layer in decoder_map:
|
151 |
+
# re-introduce pre-defined layers from the teacher
|
152 |
+
student_params["model"]["decoder"]["layers"][decoder_map[layer]] = teacher_params["model"]["decoder"][
|
153 |
+
"layers"
|
154 |
+
][layer]
|
155 |
+
|
156 |
+
if encoder_layers is not None:
|
157 |
+
student_params["model"]["encoder"]["layers"] = {}
|
158 |
+
for layer in teacher_params["model"]["encoder"]["layers"]:
|
159 |
+
if layer in encoder_map:
|
160 |
+
# re-introduce pre-defined layers from the teacher
|
161 |
+
student_params["model"]["encoder"]["layers"][encoder_map[layer]] = teacher_params["model"]["encoder"][
|
162 |
+
"layers"
|
163 |
+
][layer]
|
164 |
+
|
165 |
+
if max_source_positions is not None:
|
166 |
+
# slice the first MAX_SOURCE_POSITIONS embedding weights
|
167 |
+
student_params["model"]["encoder"]["embed_positions"]["embedding"] = teacher_params["model"]["encoder"][
|
168 |
+
"embed_positions"
|
169 |
+
]["embedding"][: student_config.max_source_positions, :]
|
170 |
+
# update the feature extractor to handle the new input length
|
171 |
+
chunk_length = int(student_config.max_source_positions * 2 / 100)
|
172 |
+
processor.feature_extractor = WhisperFeatureExtractor(chunk_length=chunk_length)
|
173 |
+
|
174 |
+
# remove the teacher params and model
|
175 |
+
del teacher_params, teacher_model
|
176 |
+
|
177 |
+
# save the converted weights and model
|
178 |
+
student_params = freeze(student_params)
|
179 |
+
student_model = FlaxWhisperForConditionalGeneration(student_config, _do_init=False)
|
180 |
+
|
181 |
+
if save_dir is not None:
|
182 |
+
student_model.save_pretrained(save_dir, params=student_params)
|
183 |
+
# we also need to correctly save the processor and generation config
|
184 |
+
processor.save_pretrained(save_dir)
|
185 |
+
generation_config.save_pretrained(save_dir)
|
186 |
+
|
187 |
+
# check we can do a forward pass with the saved model - first load the weights and processor
|
188 |
+
logger.info("Checking we can load the saved model...")
|
189 |
+
student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
190 |
+
save_dir,
|
191 |
+
_do_init=False,
|
192 |
+
)
|
193 |
+
processor = WhisperProcessor.from_pretrained(save_dir)
|
194 |
+
|
195 |
+
# define some random inputs
|
196 |
+
input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="np").input_features
|
197 |
+
decoder_start_token_id = student_model.config.decoder_start_token_id
|
198 |
+
decoder_input_ids = np.ones((input_features.shape[0], 1)) * decoder_start_token_id
|
199 |
+
|
200 |
+
# do a forward pass - outputs will be gibberish for the initialised model so we can't check them
|
201 |
+
logger.info("Checking we can run the converted model forward...")
|
202 |
+
_ = student_model(input_features, decoder_input_ids=decoder_input_ids, params=student_params).logits
|
203 |
+
logger.info("Conversion successful!")
|
204 |
+
|
205 |
+
if push_to_hub:
|
206 |
+
student_model.push_to_hub(save_dir, params=student_params)
|
207 |
+
processor.push_to_hub(save_dir)
|
208 |
+
generation_config.push_to_hub(save_dir)
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
args = parse_args()
|
213 |
+
|
214 |
+
# Set the verbosity to info of the logger - we only want one process per machine to log things on the screen
|
215 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
216 |
+
|
217 |
+
init_student_model_from_teacher(
|
218 |
+
teacher_checkpoint=args.teacher_checkpoint,
|
219 |
+
encoder_layers=args.encoder_layers,
|
220 |
+
decoder_layers=args.decoder_layers,
|
221 |
+
max_source_positions=args.max_source_positions,
|
222 |
+
save_dir=args.save_dir,
|
223 |
+
push_to_hub=args.push_to_hub,
|
224 |
+
cache_dir=args.cache_dir,
|
225 |
+
subfolder=args.subfolder,
|
226 |
+
)
|
flax/distil_whisper/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
__version__ = "0.0.1"
|
17 |
+
|
18 |
+
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
|
19 |
+
from .partitioner import PjitPartitioner
|
20 |
+
from .pipeline import FlaxWhisperPipeline
|
21 |
+
from .train_state import InferenceState
|
flax/distil_whisper/layers.py
ADDED
@@ -0,0 +1,1338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Dense attention classes and mask/weighting functions."""
|
16 |
+
|
17 |
+
# pylint: disable=attribute-defined-outside-init,g-bare-generic
|
18 |
+
|
19 |
+
import dataclasses
|
20 |
+
import functools
|
21 |
+
import operator
|
22 |
+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
|
23 |
+
|
24 |
+
import jax
|
25 |
+
import jax.numpy as jnp
|
26 |
+
import numpy as np
|
27 |
+
from flax import linen as nn
|
28 |
+
from flax.linen import partitioning as nn_partitioning
|
29 |
+
from flax.linen.dtypes import promote_dtype
|
30 |
+
from jax import lax, random
|
31 |
+
|
32 |
+
|
33 |
+
# from flax.linen.partitioning import param_with_axes, with_sharding_constraint
|
34 |
+
param_with_axes = nn_partitioning.param_with_axes
|
35 |
+
with_sharding_constraint = nn_partitioning.with_sharding_constraint
|
36 |
+
|
37 |
+
|
38 |
+
# Type annotations
|
39 |
+
Array = jnp.ndarray
|
40 |
+
DType = jnp.dtype
|
41 |
+
PRNGKey = jnp.ndarray
|
42 |
+
Shape = Iterable[int]
|
43 |
+
Activation = Callable[..., Array]
|
44 |
+
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
|
45 |
+
DotGeneralT = Callable[..., Array]
|
46 |
+
ConvGeneralDilatedT = Callable[..., Array]
|
47 |
+
PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
|
48 |
+
LaxPadding = Union[str, Sequence[Tuple[int, int]]]
|
49 |
+
|
50 |
+
# Parameter initializers.
|
51 |
+
Initializer = Callable[[PRNGKey, Shape, DType], Array]
|
52 |
+
InitializerAxis = Union[int, Tuple[int, ...]]
|
53 |
+
NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
|
54 |
+
|
55 |
+
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
|
56 |
+
|
57 |
+
|
58 |
+
# ------------------------------------------------------------------------------
|
59 |
+
# Temporary inlined JAX N-d initializer code
|
60 |
+
# TODO(levskaya): remove once new JAX release is out.
|
61 |
+
# ------------------------------------------------------------------------------
|
62 |
+
def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
|
63 |
+
"""Inlined JAX `nn.initializer._compute_fans`."""
|
64 |
+
if isinstance(in_axis, int):
|
65 |
+
in_size = shape[in_axis]
|
66 |
+
else:
|
67 |
+
in_size = int(np.prod([shape[i] for i in in_axis]))
|
68 |
+
if isinstance(out_axis, int):
|
69 |
+
out_size = shape[out_axis]
|
70 |
+
else:
|
71 |
+
out_size = int(np.prod([shape[i] for i in out_axis]))
|
72 |
+
receptive_field_size = shape.total / in_size / out_size
|
73 |
+
fan_in = in_size * receptive_field_size
|
74 |
+
fan_out = out_size * receptive_field_size
|
75 |
+
return fan_in, fan_out
|
76 |
+
|
77 |
+
|
78 |
+
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
|
79 |
+
"""Inlined JAX `nn.initializer.variance_scaling`."""
|
80 |
+
|
81 |
+
def init(key, shape, dtype=dtype):
|
82 |
+
return jnp.zeros(shape, dtype=dtype)
|
83 |
+
dtype = jax.dtypes.canonicalize_dtype(dtype)
|
84 |
+
shape = jax.core.as_named_shape(shape)
|
85 |
+
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
86 |
+
if mode == "fan_in":
|
87 |
+
denominator = fan_in
|
88 |
+
elif mode == "fan_out":
|
89 |
+
denominator = fan_out
|
90 |
+
elif mode == "fan_avg":
|
91 |
+
denominator = (fan_in + fan_out) / 2
|
92 |
+
else:
|
93 |
+
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
94 |
+
variance = jnp.array(scale / denominator, dtype=dtype)
|
95 |
+
|
96 |
+
if distribution == "truncated_normal":
|
97 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
98 |
+
stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
|
99 |
+
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
|
100 |
+
elif distribution == "normal":
|
101 |
+
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
|
102 |
+
elif distribution == "uniform":
|
103 |
+
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
|
104 |
+
else:
|
105 |
+
raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))
|
106 |
+
|
107 |
+
return init
|
108 |
+
|
109 |
+
|
110 |
+
# ------------------------------------------------------------------------------
|
111 |
+
|
112 |
+
|
113 |
+
def nd_dense_init(scale, mode, distribution):
|
114 |
+
"""Initializer with in_axis, out_axis set at call time."""
|
115 |
+
|
116 |
+
def init_fn(key, shape, dtype, in_axis, out_axis):
|
117 |
+
fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
|
118 |
+
return fn(key, shape, dtype)
|
119 |
+
|
120 |
+
return init_fn
|
121 |
+
|
122 |
+
|
123 |
+
def dot_product_attention(
|
124 |
+
query: Array,
|
125 |
+
key: Array,
|
126 |
+
value: Array,
|
127 |
+
bias: Optional[Array] = None,
|
128 |
+
dropout_rng: Optional[PRNGKey] = None,
|
129 |
+
dropout_rate: float = 0.0,
|
130 |
+
deterministic: bool = False,
|
131 |
+
dtype: DType = jnp.float32,
|
132 |
+
float32_logits: bool = False,
|
133 |
+
):
|
134 |
+
"""Computes dot-product attention given query, key, and value.
|
135 |
+
|
136 |
+
This is the core function for applying attention based on
|
137 |
+
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
|
138 |
+
query and key and combines the values using the attention weights.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
query: queries for calculating attention with shape of `[batch, q_length,
|
142 |
+
num_heads, qk_depth_per_head]`.
|
143 |
+
key: keys for calculating attention with shape of `[batch, kv_length,
|
144 |
+
num_heads, qk_depth_per_head]`.
|
145 |
+
value: values to be used in attention with shape of `[batch, kv_length,
|
146 |
+
num_heads, v_depth_per_head]`.
|
147 |
+
bias: bias for the attention weights. This should be broadcastable to the
|
148 |
+
shape `[batch, num_heads, q_length, kv_length]` This can be used for
|
149 |
+
incorporating causal masks, padding masks, proximity bias, etc.
|
150 |
+
dropout_rng: JAX PRNGKey: to be used for dropout
|
151 |
+
dropout_rate: dropout rate
|
152 |
+
deterministic: bool, deterministic or not (to apply dropout)
|
153 |
+
dtype: the dtype of the computation (default: float32)
|
154 |
+
float32_logits: bool, if True then compute logits in float32 to avoid
|
155 |
+
numerical issues with bfloat16.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
|
159 |
+
"""
|
160 |
+
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
|
161 |
+
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
|
162 |
+
assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
|
163 |
+
assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
|
164 |
+
assert query.shape[-1] == key.shape[-1], "q, k depths must match."
|
165 |
+
|
166 |
+
# Casting logits and softmax computation for float32 for model stability.
|
167 |
+
if float32_logits:
|
168 |
+
query = query.astype(jnp.float32)
|
169 |
+
key = key.astype(jnp.float32)
|
170 |
+
|
171 |
+
# `attn_weights`: [batch, num_heads, q_length, kv_length]
|
172 |
+
attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
|
173 |
+
|
174 |
+
# Apply attention bias: masking, dropout, proximity bias, etc.
|
175 |
+
if bias is not None:
|
176 |
+
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
|
177 |
+
|
178 |
+
# Normalize the attention weights across `kv_length` dimension.
|
179 |
+
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
180 |
+
|
181 |
+
# Apply attention dropout.
|
182 |
+
if not deterministic and dropout_rate > 0.0:
|
183 |
+
keep_prob = 1.0 - dropout_rate
|
184 |
+
# T5 broadcasts along the "length" dim, but unclear which one that
|
185 |
+
# corresponds to in positional dimensions here, assuming query dim.
|
186 |
+
dropout_shape = list(attn_weights.shape)
|
187 |
+
dropout_shape[-2] = 1
|
188 |
+
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
|
189 |
+
keep = jnp.broadcast_to(keep, attn_weights.shape)
|
190 |
+
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
|
191 |
+
attn_weights = attn_weights * multiplier
|
192 |
+
|
193 |
+
# Take the linear combination of `value`.
|
194 |
+
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
|
195 |
+
|
196 |
+
|
197 |
+
dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
|
198 |
+
|
199 |
+
|
200 |
+
class MultiHeadDotProductAttention(nn.Module):
|
201 |
+
"""Multi-head dot-product attention.
|
202 |
+
|
203 |
+
Attributes:
|
204 |
+
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
|
205 |
+
should be divisible by the number of heads.
|
206 |
+
head_dim: dimension of each head.
|
207 |
+
dtype: the dtype of the computation.
|
208 |
+
dropout_rate: dropout rate
|
209 |
+
kernel_init: initializer for the kernel of the Dense layers.
|
210 |
+
float32_logits: bool, if True then compute logits in float32 to avoid
|
211 |
+
numerical issues with bfloat16.
|
212 |
+
"""
|
213 |
+
|
214 |
+
num_heads: int
|
215 |
+
head_dim: int
|
216 |
+
dtype: DType = jnp.float32
|
217 |
+
dropout_rate: float = 0.0
|
218 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
|
219 |
+
float32_logits: bool = False # computes logits in float32 for stability.
|
220 |
+
|
221 |
+
@nn.compact
|
222 |
+
def __call__(
|
223 |
+
self,
|
224 |
+
inputs_q: Array,
|
225 |
+
inputs_kv: Array,
|
226 |
+
mask: Optional[Array] = None,
|
227 |
+
bias: Optional[Array] = None,
|
228 |
+
*,
|
229 |
+
decode: bool = False,
|
230 |
+
deterministic: bool = False,
|
231 |
+
) -> Array:
|
232 |
+
"""Applies multi-head dot product attention on the input data.
|
233 |
+
|
234 |
+
Projects the inputs into multi-headed query, key, and value vectors,
|
235 |
+
applies dot-product attention and project the results to an output vector.
|
236 |
+
|
237 |
+
There are two modes: decoding and non-decoding (e.g., training). The mode is
|
238 |
+
determined by `decode` argument. For decoding, this method is called twice,
|
239 |
+
first to initialize the cache and then for an actual decoding process. The
|
240 |
+
two calls are differentiated by the presence of 'cached_key' in the variable
|
241 |
+
dict. In the cache initialization stage, the cache variables are initialized
|
242 |
+
as zeros and will be filled in the subsequent decoding process.
|
243 |
+
|
244 |
+
In the cache initialization call, `inputs_q` has a shape [batch, length,
|
245 |
+
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
|
246 |
+
incremental decoding stage, query, key and value all have the shape [batch,
|
247 |
+
1, qkv_features] corresponding to a single step.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
inputs_q: input queries of shape `[batch, q_length, q_features]`.
|
251 |
+
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
|
252 |
+
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
|
253 |
+
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
|
254 |
+
decode: Whether to prepare and use an autoregressive cache.
|
255 |
+
deterministic: Disables dropout if set to True.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
output of shape `[batch, length, q_features]`.
|
259 |
+
"""
|
260 |
+
projection = functools.partial(
|
261 |
+
DenseGeneral,
|
262 |
+
axis=-1,
|
263 |
+
features=(self.num_heads, self.head_dim),
|
264 |
+
kernel_axes=("embed", "heads", "kv"),
|
265 |
+
dtype=self.dtype,
|
266 |
+
)
|
267 |
+
|
268 |
+
# NOTE: T5 does not explicitly rescale the attention logits by
|
269 |
+
# 1/sqrt(depth_kq)! This is folded into the initializers of the
|
270 |
+
# linear transformations, which is equivalent under Adafactor.
|
271 |
+
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
|
272 |
+
|
273 |
+
def query_init(*args):
|
274 |
+
return self.kernel_init(*args) / depth_scaling
|
275 |
+
|
276 |
+
# Project inputs_q to multi-headed q/k/v
|
277 |
+
# dimensions are then [batch, length, num_heads, head_dim]
|
278 |
+
query = projection(kernel_init=query_init, name="query")(inputs_q)
|
279 |
+
key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
|
280 |
+
value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
|
281 |
+
|
282 |
+
query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
|
283 |
+
key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
|
284 |
+
value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))
|
285 |
+
|
286 |
+
if decode:
|
287 |
+
# Detect if we're initializing by absence of existing cache data.
|
288 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
289 |
+
|
290 |
+
# The key and value have dimension [batch, length, num_heads, head_dim],
|
291 |
+
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
|
292 |
+
# fusion optimization. This also enables the "scatter via one-hot
|
293 |
+
# broadcast" trick, which means we do a one-hot broadcast instead of a
|
294 |
+
# scatter/gather operations, resulting in a 3-4x speedup in practice.
|
295 |
+
def swap_dims(x):
|
296 |
+
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
|
297 |
+
|
298 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
|
299 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
|
300 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
301 |
+
if is_initialized:
|
302 |
+
batch, num_heads, head_dim, length = cached_key.value.shape
|
303 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
304 |
+
# and cache the keys and values step by step.
|
305 |
+
# Sanity shape check of cached key against input query.
|
306 |
+
expected_shape = (batch, 1, num_heads, head_dim)
|
307 |
+
if expected_shape != query.shape:
|
308 |
+
raise ValueError(
|
309 |
+
"Autoregressive cache shape error, "
|
310 |
+
"expected query shape %s instead got %s." % (expected_shape, query.shape)
|
311 |
+
)
|
312 |
+
|
313 |
+
# Create a OHE of the current index. NOTE: the index is increased below.
|
314 |
+
cur_index = cache_index.value
|
315 |
+
one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
|
316 |
+
# In order to update the key, value caches with the current key and
|
317 |
+
# value, we move the length axis to the back, similar to what we did for
|
318 |
+
# the cached ones above.
|
319 |
+
# Note these are currently the key and value of a single position, since
|
320 |
+
# we feed one position at a time.
|
321 |
+
one_token_key = jnp.moveaxis(key, -3, -1)
|
322 |
+
one_token_value = jnp.moveaxis(value, -3, -1)
|
323 |
+
# Update key, value caches with our new 1d spatial slices.
|
324 |
+
# We implement an efficient scatter into the cache via one-hot
|
325 |
+
# broadcast and addition.
|
326 |
+
key = cached_key.value + one_token_key * one_hot_indices
|
327 |
+
value = cached_value.value + one_token_value * one_hot_indices
|
328 |
+
cached_key.value = key
|
329 |
+
cached_value.value = value
|
330 |
+
cache_index.value = cache_index.value + 1
|
331 |
+
# Move the keys and values back to their original shapes.
|
332 |
+
key = jnp.moveaxis(key, -1, -3)
|
333 |
+
value = jnp.moveaxis(value, -1, -3)
|
334 |
+
|
335 |
+
# Causal mask for cached decoder self-attention: our single query
|
336 |
+
# position should only attend to those key positions that have already
|
337 |
+
# been generated and cached, not the remaining zero elements.
|
338 |
+
mask = combine_masks(
|
339 |
+
mask,
|
340 |
+
jnp.broadcast_to(
|
341 |
+
jnp.arange(length) <= cur_index,
|
342 |
+
# (1, 1, length) represent (head dim, query length, key length)
|
343 |
+
# query length is 1 because during decoding we deal with one
|
344 |
+
# index.
|
345 |
+
# The same mask is applied to all batch elements and heads.
|
346 |
+
(batch, 1, 1, length),
|
347 |
+
),
|
348 |
+
)
|
349 |
+
|
350 |
+
# Grab the correct relative attention bias during decoding. This is
|
351 |
+
# only required during single step decoding.
|
352 |
+
if bias is not None:
|
353 |
+
# The bias is a full attention matrix, but during decoding we only
|
354 |
+
# have to take a slice of it.
|
355 |
+
# This is equivalent to bias[..., cur_index:cur_index+1, :].
|
356 |
+
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
|
357 |
+
|
358 |
+
# Convert the boolean attention mask to an attention bias.
|
359 |
+
if mask is not None:
|
360 |
+
# attention mask in the form of attention bias
|
361 |
+
attention_bias = lax.select(
|
362 |
+
mask > 0,
|
363 |
+
jnp.full(mask.shape, 0.0).astype(self.dtype),
|
364 |
+
jnp.full(mask.shape, -1e10).astype(self.dtype),
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
attention_bias = None
|
368 |
+
|
369 |
+
# Add provided bias term (e.g. relative position embedding).
|
370 |
+
if bias is not None:
|
371 |
+
attention_bias = combine_biases(attention_bias, bias)
|
372 |
+
|
373 |
+
dropout_rng = None
|
374 |
+
if not deterministic and self.dropout_rate > 0.0:
|
375 |
+
dropout_rng = self.make_rng("dropout")
|
376 |
+
|
377 |
+
# Apply attention.
|
378 |
+
x = dot_product_attention(
|
379 |
+
query,
|
380 |
+
key,
|
381 |
+
value,
|
382 |
+
bias=attention_bias,
|
383 |
+
dropout_rng=dropout_rng,
|
384 |
+
dropout_rate=self.dropout_rate,
|
385 |
+
deterministic=deterministic,
|
386 |
+
dtype=self.dtype,
|
387 |
+
float32_logits=self.float32_logits,
|
388 |
+
)
|
389 |
+
|
390 |
+
# Back to the original inputs dimensions.
|
391 |
+
out = DenseGeneral(
|
392 |
+
features=inputs_q.shape[-1], # output dim is set to the input dim.
|
393 |
+
axis=(-2, -1),
|
394 |
+
kernel_init=self.kernel_init,
|
395 |
+
kernel_axes=("heads", "kv", "embed"),
|
396 |
+
dtype=self.dtype,
|
397 |
+
name="out",
|
398 |
+
)(x)
|
399 |
+
return out
|
400 |
+
|
401 |
+
|
402 |
+
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
|
403 |
+
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
|
404 |
+
return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
|
405 |
+
|
406 |
+
|
407 |
+
def _canonicalize_tuple(x):
|
408 |
+
if isinstance(x, Iterable):
|
409 |
+
return tuple(x)
|
410 |
+
else:
|
411 |
+
return (x,)
|
412 |
+
|
413 |
+
|
414 |
+
# ------------------------------------------------------------------------------
|
415 |
+
# DenseGeneral for attention layers.
|
416 |
+
# ------------------------------------------------------------------------------
|
417 |
+
class DenseGeneral(nn.Module):
|
418 |
+
"""A linear transformation (without bias) with flexible axes.
|
419 |
+
|
420 |
+
Attributes:
|
421 |
+
features: tuple with numbers of output features.
|
422 |
+
axis: tuple with axes to apply the transformation on.
|
423 |
+
dtype: the dtype of the computation (default: float32).
|
424 |
+
kernel_init: initializer function for the weight matrix.
|
425 |
+
"""
|
426 |
+
|
427 |
+
features: Union[Iterable[int], int]
|
428 |
+
axis: Union[Iterable[int], int] = -1
|
429 |
+
dtype: DType = jnp.float32
|
430 |
+
params_dtype: DType = jnp.float32
|
431 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
|
432 |
+
kernel_axes: Tuple[str, ...] = ()
|
433 |
+
use_bias: bool = True
|
434 |
+
bias_init: Any = nn.initializers.zeros
|
435 |
+
|
436 |
+
@nn.compact
|
437 |
+
def __call__(self, inputs: Array) -> Array:
|
438 |
+
"""Applies a linear transformation to the inputs along multiple dimensions.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
inputs: The nd-array to be transformed.
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
The transformed input.
|
445 |
+
"""
|
446 |
+
features = _canonicalize_tuple(self.features)
|
447 |
+
axis = _canonicalize_tuple(self.axis)
|
448 |
+
|
449 |
+
inputs = jnp.asarray(inputs, self.dtype)
|
450 |
+
axis = _normalize_axes(axis, inputs.ndim)
|
451 |
+
|
452 |
+
kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
|
453 |
+
kernel_in_axis = np.arange(len(axis))
|
454 |
+
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
|
455 |
+
kernel = param_with_axes(
|
456 |
+
"kernel",
|
457 |
+
self.kernel_init,
|
458 |
+
kernel_shape,
|
459 |
+
self.params_dtype,
|
460 |
+
kernel_in_axis,
|
461 |
+
kernel_out_axis,
|
462 |
+
axes=self.kernel_axes,
|
463 |
+
)
|
464 |
+
if self.use_bias:
|
465 |
+
bias = param_with_axes(
|
466 |
+
"bias",
|
467 |
+
self.bias_init,
|
468 |
+
features,
|
469 |
+
self.params_dtype,
|
470 |
+
axes=(self.kernel_axes[-1],),
|
471 |
+
)
|
472 |
+
kernel = jnp.asarray(kernel, self.dtype)
|
473 |
+
|
474 |
+
contract_ind = tuple(range(0, len(axis)))
|
475 |
+
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
|
476 |
+
if self.use_bias:
|
477 |
+
bias = jnp.asarray(bias, self.dtype)
|
478 |
+
# y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
|
479 |
+
y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:])
|
480 |
+
return y
|
481 |
+
|
482 |
+
|
483 |
+
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
|
484 |
+
"""Convert a string to an activation function."""
|
485 |
+
if fn_or_string == "linear":
|
486 |
+
return lambda x: x
|
487 |
+
elif isinstance(fn_or_string, str):
|
488 |
+
return getattr(nn, fn_or_string)
|
489 |
+
elif callable(fn_or_string):
|
490 |
+
return fn_or_string
|
491 |
+
else:
|
492 |
+
raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
|
493 |
+
|
494 |
+
|
495 |
+
class MlpBlock(nn.Module):
|
496 |
+
"""Transformer MLP / feed-forward block.
|
497 |
+
|
498 |
+
Attributes:
|
499 |
+
intermediate_dim: Shared dimension of hidden layers.
|
500 |
+
activations: Type of activations for each layer. Each element is either
|
501 |
+
'linear', a string function name in flax.linen, or a function.
|
502 |
+
kernel_init: Kernel function, passed to the dense layers.
|
503 |
+
deterministic: Whether the dropout layers should be deterministic.
|
504 |
+
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
|
505 |
+
dtype: Type for the dense layer.
|
506 |
+
"""
|
507 |
+
|
508 |
+
intermediate_dim: int = 2048
|
509 |
+
activations: Sequence[Union[str, Callable]] = ("relu",)
|
510 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
|
511 |
+
intermediate_dropout_rate: float = 0.1
|
512 |
+
dtype: Any = jnp.float32
|
513 |
+
|
514 |
+
@nn.compact
|
515 |
+
def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
|
516 |
+
"""Applies Transformer MlpBlock module."""
|
517 |
+
# Iterate over specified MLP input activation functions.
|
518 |
+
# e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
|
519 |
+
activations = []
|
520 |
+
for idx, act_fn in enumerate(self.activations):
|
521 |
+
dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
|
522 |
+
x = DenseGeneral(
|
523 |
+
self.intermediate_dim,
|
524 |
+
dtype=self.dtype,
|
525 |
+
kernel_init=self.kernel_init,
|
526 |
+
kernel_axes=("embed", "mlp"),
|
527 |
+
name=dense_name,
|
528 |
+
)(inputs)
|
529 |
+
x = _convert_to_activation_function(act_fn)(x)
|
530 |
+
activations.append(x)
|
531 |
+
|
532 |
+
# Take elementwise product of above intermediate activations.
|
533 |
+
x = functools.reduce(operator.mul, activations)
|
534 |
+
# Apply dropout and final dense output projection.
|
535 |
+
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
|
536 |
+
x, deterministic=deterministic
|
537 |
+
) # Broadcast along length.
|
538 |
+
x = with_sharding_constraint(x, ("batch", "length", "mlp"))
|
539 |
+
output = DenseGeneral(
|
540 |
+
inputs.shape[-1],
|
541 |
+
dtype=self.dtype,
|
542 |
+
kernel_init=self.kernel_init,
|
543 |
+
kernel_axes=("mlp", "embed"),
|
544 |
+
name="wo",
|
545 |
+
)(x)
|
546 |
+
return output
|
547 |
+
|
548 |
+
|
549 |
+
class Embed(nn.Module):
|
550 |
+
"""A parameterized function from integers [0, n) to d-dimensional vectors.
|
551 |
+
|
552 |
+
Attributes:
|
553 |
+
num_embeddings: number of embeddings.
|
554 |
+
features: number of feature dimensions for each embedding.
|
555 |
+
dtype: the dtype of the embedding vectors (default: float32).
|
556 |
+
embedding_init: embedding initializer.
|
557 |
+
one_hot: performs the gather with a one-hot contraction rather than a true
|
558 |
+
gather. This is currently needed for SPMD partitioning.
|
559 |
+
"""
|
560 |
+
|
561 |
+
num_embeddings: int
|
562 |
+
features: int
|
563 |
+
cast_input_dtype: Optional[DType] = None
|
564 |
+
dtype: DType = jnp.float32
|
565 |
+
params_dtype: DType = jnp.float32
|
566 |
+
attend_dtype: Optional[DType] = None
|
567 |
+
embedding_init: Initializer = default_embed_init
|
568 |
+
one_hot: bool = True
|
569 |
+
embedding: Array = dataclasses.field(init=False)
|
570 |
+
|
571 |
+
def setup(self):
|
572 |
+
self.embedding = param_with_axes(
|
573 |
+
"embedding",
|
574 |
+
self.embedding_init,
|
575 |
+
(self.num_embeddings, self.features),
|
576 |
+
self.params_dtype,
|
577 |
+
axes=("vocab", "embed"),
|
578 |
+
)
|
579 |
+
|
580 |
+
def __call__(self, inputs: Array) -> Array:
|
581 |
+
"""Embeds the inputs along the last dimension.
|
582 |
+
|
583 |
+
Args:
|
584 |
+
inputs: input data, all dimensions are considered batch dimensions.
|
585 |
+
|
586 |
+
Returns:
|
587 |
+
Output which is embedded input data. The output shape follows the input,
|
588 |
+
with an additional `features` dimension appended.
|
589 |
+
"""
|
590 |
+
if self.cast_input_dtype:
|
591 |
+
inputs = inputs.astype(self.cast_input_dtype)
|
592 |
+
if not jnp.issubdtype(inputs.dtype, jnp.integer):
|
593 |
+
raise ValueError("Input type must be an integer or unsigned integer.")
|
594 |
+
if self.one_hot:
|
595 |
+
iota = lax.iota(jnp.int32, self.num_embeddings)
|
596 |
+
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
|
597 |
+
output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
|
598 |
+
else:
|
599 |
+
output = jnp.asarray(self.embedding, self.dtype)[inputs]
|
600 |
+
output = with_sharding_constraint(output, ("batch", "length", "embed"))
|
601 |
+
return output
|
602 |
+
|
603 |
+
def attend(self, query: Array) -> Array:
|
604 |
+
"""Attend over the embedding using a query array.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
query: array with last dimension equal the feature depth `features` of the
|
608 |
+
embedding.
|
609 |
+
|
610 |
+
Returns:
|
611 |
+
An array with final dim `num_embeddings` corresponding to the batched
|
612 |
+
inner-product of the array of query vectors against each embedding.
|
613 |
+
Commonly used for weight-sharing between embeddings and logit transform
|
614 |
+
in NLP models.
|
615 |
+
"""
|
616 |
+
dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
|
617 |
+
return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
|
618 |
+
|
619 |
+
|
620 |
+
class RelativePositionBiases(nn.Module):
|
621 |
+
"""Adds T5-style relative positional embeddings to the attention logits.
|
622 |
+
|
623 |
+
Attributes:
|
624 |
+
num_buckets: Number of buckets to bucket distances between key and query
|
625 |
+
positions into.
|
626 |
+
max_distance: Maximum distance before everything is lumped into the last
|
627 |
+
distance bucket.
|
628 |
+
num_heads: Number of heads in the attention layer. Each head will get a
|
629 |
+
different relative position weighting.
|
630 |
+
dtype: Type of arrays through this module.
|
631 |
+
embedding_init: initializer for relative embedding table.
|
632 |
+
"""
|
633 |
+
|
634 |
+
num_buckets: int
|
635 |
+
max_distance: int
|
636 |
+
num_heads: int
|
637 |
+
dtype: Any
|
638 |
+
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
|
639 |
+
|
640 |
+
@staticmethod
|
641 |
+
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
642 |
+
"""Translate relative position to a bucket number for relative attention.
|
643 |
+
|
644 |
+
The relative position is defined as memory_position - query_position, i.e.
|
645 |
+
the distance in tokens from the attending position to the attended-to
|
646 |
+
position. If bidirectional=False, then positive relative positions are
|
647 |
+
invalid.
|
648 |
+
We use smaller buckets for small absolute relative_position and larger
|
649 |
+
buckets for larger absolute relative_positions. All relative
|
650 |
+
positions >=max_distance map to the same bucket. All relative
|
651 |
+
positions <=-max_distance map to the same bucket. This should allow for
|
652 |
+
more graceful generalization to longer sequences than the model has been
|
653 |
+
trained on.
|
654 |
+
|
655 |
+
Args:
|
656 |
+
relative_position: an int32 array
|
657 |
+
bidirectional: a boolean - whether the attention is bidirectional
|
658 |
+
num_buckets: an integer
|
659 |
+
max_distance: an integer
|
660 |
+
|
661 |
+
Returns:
|
662 |
+
a Tensor with the same shape as relative_position, containing int32
|
663 |
+
values in the range [0, num_buckets)
|
664 |
+
"""
|
665 |
+
ret = 0
|
666 |
+
n = -relative_position
|
667 |
+
if bidirectional:
|
668 |
+
num_buckets //= 2
|
669 |
+
ret += (n < 0).astype(np.int32) * num_buckets
|
670 |
+
n = np.abs(n)
|
671 |
+
else:
|
672 |
+
n = np.maximum(n, 0)
|
673 |
+
# now n is in the range [0, inf)
|
674 |
+
max_exact = num_buckets // 2
|
675 |
+
is_small = n < max_exact
|
676 |
+
val_if_large = max_exact + (
|
677 |
+
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
|
678 |
+
/ np.log(max_distance / max_exact)
|
679 |
+
* (num_buckets - max_exact)
|
680 |
+
).astype(np.int32)
|
681 |
+
val_if_large = np.minimum(val_if_large, num_buckets - 1)
|
682 |
+
ret += np.where(is_small, n, val_if_large)
|
683 |
+
return ret
|
684 |
+
|
685 |
+
@nn.compact
|
686 |
+
def __call__(self, qlen, klen, bidirectional=True):
|
687 |
+
"""Produce relative position embedding attention biases.
|
688 |
+
|
689 |
+
Args:
|
690 |
+
qlen: attention query length.
|
691 |
+
klen: attention key length.
|
692 |
+
bidirectional: whether to allow positive memory-query relative position
|
693 |
+
embeddings.
|
694 |
+
|
695 |
+
Returns:
|
696 |
+
output: `(1, len, q_len, k_len)` attention bias
|
697 |
+
"""
|
698 |
+
# TODO(levskaya): should we be computing this w. numpy as a program
|
699 |
+
# constant?
|
700 |
+
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
|
701 |
+
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
|
702 |
+
relative_position = memory_position - context_position # shape (qlen, klen)
|
703 |
+
rp_bucket = self._relative_position_bucket(
|
704 |
+
relative_position,
|
705 |
+
bidirectional=bidirectional,
|
706 |
+
num_buckets=self.num_buckets,
|
707 |
+
max_distance=self.max_distance,
|
708 |
+
)
|
709 |
+
relative_attention_bias = param_with_axes(
|
710 |
+
"rel_embedding",
|
711 |
+
self.embedding_init,
|
712 |
+
(self.num_heads, self.num_buckets),
|
713 |
+
jnp.float32,
|
714 |
+
axes=("heads", "relpos_buckets"),
|
715 |
+
)
|
716 |
+
|
717 |
+
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
|
718 |
+
# Instead of using a slow gather, we create a leading-dimension one-hot
|
719 |
+
# array from rp_bucket and use it to perform the gather-equivalent via a
|
720 |
+
# contraction, i.e.:
|
721 |
+
# (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
|
722 |
+
# This is equivalent to relative_attention_bias[:, rp_bucket]
|
723 |
+
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
|
724 |
+
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
|
725 |
+
# --> shape (qlen, klen, num_heads)
|
726 |
+
values = lax.dot_general(
|
727 |
+
relative_attention_bias,
|
728 |
+
rp_bucket_one_hot,
|
729 |
+
(((1,), (0,)), ((), ())), # rhs, lhs contracting dims
|
730 |
+
) # no batched dims
|
731 |
+
# Add a singleton batch dimension.
|
732 |
+
# --> shape (1, num_heads, qlen, klen)
|
733 |
+
return values[jnp.newaxis, ...]
|
734 |
+
|
735 |
+
|
736 |
+
# ------------------------------------------------------------------------------
|
737 |
+
# T5 Layernorm - no subtraction of mean or bias.
|
738 |
+
# ------------------------------------------------------------------------------
|
739 |
+
# class LayerNorm(nn.Module):
|
740 |
+
# """T5 Layer normalization operating on the last axis of the input data."""
|
741 |
+
# epsilon: float = 1e-6
|
742 |
+
# dtype: Any = jnp.float32
|
743 |
+
# scale_init: Initializer = nn.initializers.ones
|
744 |
+
|
745 |
+
# @nn.compact
|
746 |
+
# def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
747 |
+
# """Applies layer normalization on the input."""
|
748 |
+
# x = jnp.asarray(x, jnp.float32)
|
749 |
+
# features = x.shape[-1]
|
750 |
+
# mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
|
751 |
+
# y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
|
752 |
+
# scale = param_with_axes(
|
753 |
+
# 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
|
754 |
+
|
755 |
+
# scale = jnp.asarray(scale, self.dtype)
|
756 |
+
# return y * scale
|
757 |
+
|
758 |
+
|
759 |
+
class LayerNorm(nn.Module):
|
760 |
+
"""Layer normalization (https://arxiv.org/abs/1607.06450).
|
761 |
+
Operates on the last axis of the input data.
|
762 |
+
It normalizes the activations of the layer for each given example in a
|
763 |
+
batch independently, rather than across a batch like Batch Normalization.
|
764 |
+
i.e. applies a transformation that maintains the mean activation within
|
765 |
+
each example close to 0 and the activation standard deviation close to 1.
|
766 |
+
Attributes:
|
767 |
+
epsilon: A small float added to variance to avoid dividing by zero.
|
768 |
+
dtype: the dtype of the computation (default: float32).
|
769 |
+
use_bias: If True, bias (beta) is added.
|
770 |
+
use_scale: If True, multiply by scale (gamma). When the next layer is linear
|
771 |
+
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
772 |
+
by the next layer.
|
773 |
+
bias_init: Initializer for bias, by default, zero.
|
774 |
+
scale_init: Initializer for scale, by default, one.
|
775 |
+
"""
|
776 |
+
|
777 |
+
epsilon: float = 1e-6
|
778 |
+
dtype: Any = jnp.float32
|
779 |
+
params_dtype: DType = jnp.float32
|
780 |
+
use_bias: bool = True
|
781 |
+
use_scale: bool = True
|
782 |
+
bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
|
783 |
+
scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones
|
784 |
+
|
785 |
+
@nn.compact
|
786 |
+
def __call__(self, x):
|
787 |
+
"""Applies layer normalization on the input.
|
788 |
+
Args:
|
789 |
+
x: the inputs
|
790 |
+
Returns:
|
791 |
+
Normalized inputs (the same shape as inputs).
|
792 |
+
"""
|
793 |
+
x = jnp.asarray(x, jnp.float32)
|
794 |
+
features = x.shape[-1]
|
795 |
+
mean = jnp.mean(x, axis=-1, keepdims=True)
|
796 |
+
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
|
797 |
+
var = mean2 - lax.square(mean)
|
798 |
+
mul = lax.rsqrt(var + self.epsilon)
|
799 |
+
if self.use_scale:
|
800 |
+
scale = param_with_axes(
|
801 |
+
"scale",
|
802 |
+
self.scale_init,
|
803 |
+
(features,),
|
804 |
+
self.params_dtype,
|
805 |
+
axes=("embed",),
|
806 |
+
)
|
807 |
+
mul = mul * jnp.asarray(scale, self.dtype)
|
808 |
+
y = (x - mean) * mul
|
809 |
+
if self.use_bias:
|
810 |
+
bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
|
811 |
+
y = y + jnp.asarray(bias, self.dtype)
|
812 |
+
return jnp.asarray(y, self.dtype)
|
813 |
+
|
814 |
+
|
815 |
+
# ------------------------------------------------------------------------------
|
816 |
+
# Mask-making utility functions.
|
817 |
+
# ------------------------------------------------------------------------------
|
818 |
+
def make_attention_mask(
|
819 |
+
query_input: Array,
|
820 |
+
key_input: Array,
|
821 |
+
pairwise_fn: Callable = jnp.multiply,
|
822 |
+
extra_batch_dims: int = 0,
|
823 |
+
dtype: DType = jnp.float32,
|
824 |
+
) -> Array:
|
825 |
+
"""Mask-making helper for attention weights.
|
826 |
+
|
827 |
+
In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
|
828 |
+
attention weights will be `[batch, heads, len_q, len_kv]` and this
|
829 |
+
function will produce `[batch, 1, len_q, len_kv]`.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
query_input: a batched, flat input of query_length size
|
833 |
+
key_input: a batched, flat input of key_length size
|
834 |
+
pairwise_fn: broadcasting elementwise comparison function
|
835 |
+
extra_batch_dims: number of extra batch dims to add singleton axes for, none
|
836 |
+
by default
|
837 |
+
dtype: mask return dtype
|
838 |
+
|
839 |
+
Returns:
|
840 |
+
A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
|
841 |
+
"""
|
842 |
+
# [batch, len_q, len_kv]
|
843 |
+
mask = pairwise_fn(
|
844 |
+
# [batch, len_q] -> [batch, len_q, 1]
|
845 |
+
jnp.expand_dims(query_input, axis=-1),
|
846 |
+
# [batch, len_q] -> [batch, 1, len_kv]
|
847 |
+
jnp.expand_dims(key_input, axis=-2),
|
848 |
+
)
|
849 |
+
|
850 |
+
# [batch, 1, len_q, len_kv]. This creates the head dim.
|
851 |
+
mask = jnp.expand_dims(mask, axis=-3)
|
852 |
+
mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
|
853 |
+
return mask.astype(dtype)
|
854 |
+
|
855 |
+
|
856 |
+
def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
|
857 |
+
"""Make a causal mask for self-attention.
|
858 |
+
|
859 |
+
In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
|
860 |
+
will be `[batch, heads, len, len]` and this function will produce a
|
861 |
+
causal mask of shape `[batch, 1, len, len]`.
|
862 |
+
|
863 |
+
Note that a causal mask does not depend on the values of x; it only depends on
|
864 |
+
the shape. If x has padding elements, they will not be treated in a special
|
865 |
+
manner.
|
866 |
+
|
867 |
+
Args:
|
868 |
+
x: input array of shape `[batch, len]`
|
869 |
+
extra_batch_dims: number of batch dims to add singleton axes for, none by
|
870 |
+
default
|
871 |
+
dtype: mask return dtype
|
872 |
+
|
873 |
+
Returns:
|
874 |
+
A `[batch, 1, len, len]` shaped causal mask for 1d attention.
|
875 |
+
"""
|
876 |
+
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
|
877 |
+
return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype)
|
878 |
+
|
879 |
+
|
880 |
+
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
|
881 |
+
"""Combine attention masks.
|
882 |
+
|
883 |
+
Args:
|
884 |
+
*masks: set of attention mask arguments to combine, some can be None.
|
885 |
+
dtype: final mask dtype
|
886 |
+
|
887 |
+
Returns:
|
888 |
+
Combined mask, reduced by logical and, returns None if no masks given.
|
889 |
+
"""
|
890 |
+
masks = [m for m in masks if m is not None]
|
891 |
+
if not masks:
|
892 |
+
return None
|
893 |
+
assert all(
|
894 |
+
(x.ndim == masks[0].ndim for x in masks)
|
895 |
+
), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
|
896 |
+
mask, *other_masks = masks
|
897 |
+
for other_mask in other_masks:
|
898 |
+
mask = jnp.logical_and(mask, other_mask)
|
899 |
+
return mask.astype(dtype)
|
900 |
+
|
901 |
+
|
902 |
+
def combine_biases(*masks: Optional[Array]):
|
903 |
+
"""Combine attention biases.
|
904 |
+
|
905 |
+
Args:
|
906 |
+
*masks: set of attention bias arguments to combine, some can be None.
|
907 |
+
|
908 |
+
Returns:
|
909 |
+
Combined mask, reduced by summation, returns None if no masks given.
|
910 |
+
"""
|
911 |
+
masks = [m for m in masks if m is not None]
|
912 |
+
if not masks:
|
913 |
+
return None
|
914 |
+
assert all(
|
915 |
+
(x.ndim == masks[0].ndim for x in masks)
|
916 |
+
), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
|
917 |
+
mask, *other_masks = masks
|
918 |
+
for other_mask in other_masks:
|
919 |
+
mask = mask + other_mask
|
920 |
+
return mask
|
921 |
+
|
922 |
+
|
923 |
+
def make_decoder_mask(
|
924 |
+
decoder_target_tokens: Array,
|
925 |
+
dtype: DType,
|
926 |
+
decoder_causal_attention: Optional[Array] = None,
|
927 |
+
decoder_segment_ids: Optional[Array] = None,
|
928 |
+
) -> Array:
|
929 |
+
"""Compute the self-attention mask for a decoder.
|
930 |
+
|
931 |
+
Decoder mask is formed by combining a causal mask, a padding mask and an
|
932 |
+
optional packing mask. If decoder_causal_attention is passed, it makes the
|
933 |
+
masking non-causal for positions that have value of 1.
|
934 |
+
|
935 |
+
A prefix LM is applied to a dataset which has a notion of "inputs" and
|
936 |
+
"targets", e.g., a machine translation task. The inputs and targets are
|
937 |
+
concatenated to form a new target. `decoder_target_tokens` is the concatenated
|
938 |
+
decoder output tokens.
|
939 |
+
|
940 |
+
The "inputs" portion of the concatenated sequence can attend to other "inputs"
|
941 |
+
tokens even for those at a later time steps. In order to control this
|
942 |
+
behavior, `decoder_causal_attention` is necessary. This is a binary mask with
|
943 |
+
a value of 1 indicating that the position belonged to "inputs" portion of the
|
944 |
+
original dataset.
|
945 |
+
|
946 |
+
Example:
|
947 |
+
|
948 |
+
Suppose we have a dataset with two examples.
|
949 |
+
|
950 |
+
ds = [{"inputs": [6, 7], "targets": [8]},
|
951 |
+
{"inputs": [3, 4], "targets": [5]}]
|
952 |
+
|
953 |
+
After the data preprocessing with packing, the two examples are packed into
|
954 |
+
one example with the following three fields (some fields are skipped for
|
955 |
+
simplicity).
|
956 |
+
|
957 |
+
decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
|
958 |
+
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
|
959 |
+
decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
|
960 |
+
|
961 |
+
where each array has [batch, length] shape with batch size being 1. Then,
|
962 |
+
this function computes the following mask.
|
963 |
+
|
964 |
+
mask = [[[[1, 1, 0, 0, 0, 0, 0],
|
965 |
+
[1, 1, 0, 0, 0, 0, 0],
|
966 |
+
[1, 1, 1, 0, 0, 0, 0],
|
967 |
+
[0, 0, 0, 1, 1, 0, 0],
|
968 |
+
[0, 0, 0, 1, 1, 0, 0],
|
969 |
+
[0, 0, 0, 1, 1, 1, 0],
|
970 |
+
[0, 0, 0, 0, 0, 0, 0]]]]
|
971 |
+
|
972 |
+
mask[b, 1, :, :] represents the mask for the example `b` in the batch.
|
973 |
+
Because mask is for a self-attention layer, the mask's shape is a square of
|
974 |
+
shape [query length, key length].
|
975 |
+
|
976 |
+
mask[b, 1, i, j] = 1 means that the query token at position i can attend to
|
977 |
+
the key token at position j.
|
978 |
+
|
979 |
+
Args:
|
980 |
+
decoder_target_tokens: decoder output tokens. [batch, length]
|
981 |
+
dtype: dtype of the output mask.
|
982 |
+
decoder_causal_attention: a binary mask indicating which position should
|
983 |
+
only attend to earlier positions in the sequence. Others will attend
|
984 |
+
bidirectionally. [batch, length]
|
985 |
+
decoder_segment_ids: decoder segmentation info for packed examples. [batch,
|
986 |
+
length]
|
987 |
+
|
988 |
+
Returns:
|
989 |
+
the combined decoder mask.
|
990 |
+
"""
|
991 |
+
masks = []
|
992 |
+
# The same mask is applied to all attention heads. So the head dimension is 1,
|
993 |
+
# i.e., the mask will be broadcast along the heads dim.
|
994 |
+
# [batch, 1, length, length]
|
995 |
+
causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
|
996 |
+
|
997 |
+
# Positions with value 1 in `decoder_causal_attneition` can attend
|
998 |
+
# bidirectionally.
|
999 |
+
if decoder_causal_attention is not None:
|
1000 |
+
# [batch, 1, length, length]
|
1001 |
+
inputs_mask = make_attention_mask(
|
1002 |
+
decoder_causal_attention,
|
1003 |
+
decoder_causal_attention,
|
1004 |
+
jnp.logical_and,
|
1005 |
+
dtype=dtype,
|
1006 |
+
)
|
1007 |
+
masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
|
1008 |
+
else:
|
1009 |
+
masks.append(causal_mask)
|
1010 |
+
|
1011 |
+
# Padding mask.
|
1012 |
+
masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
|
1013 |
+
|
1014 |
+
# Packing mask
|
1015 |
+
if decoder_segment_ids is not None:
|
1016 |
+
masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
|
1017 |
+
|
1018 |
+
return combine_masks(*masks, dtype=dtype)
|
1019 |
+
|
1020 |
+
|
1021 |
+
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
|
1022 |
+
""" "Canonicalizes conv padding to a jax.lax supported format."""
|
1023 |
+
if isinstance(padding, str):
|
1024 |
+
return padding
|
1025 |
+
if isinstance(padding, int):
|
1026 |
+
return [(padding, padding)] * rank
|
1027 |
+
if isinstance(padding, Sequence) and len(padding) == rank:
|
1028 |
+
new_pad = []
|
1029 |
+
for p in padding:
|
1030 |
+
if isinstance(p, int):
|
1031 |
+
new_pad.append((p, p))
|
1032 |
+
elif isinstance(p, tuple) and len(p) == 2:
|
1033 |
+
new_pad.append(p)
|
1034 |
+
else:
|
1035 |
+
break
|
1036 |
+
if len(new_pad) == rank:
|
1037 |
+
return new_pad
|
1038 |
+
raise ValueError(
|
1039 |
+
f"Invalid padding format: {padding}, should be str, int,"
|
1040 |
+
f" or a sequence of len {rank} where each element is an"
|
1041 |
+
" int or pair of ints."
|
1042 |
+
)
|
1043 |
+
|
1044 |
+
|
1045 |
+
def _conv_dimension_numbers(input_shape):
|
1046 |
+
"""Computes the dimension numbers based on the input shape."""
|
1047 |
+
ndim = len(input_shape)
|
1048 |
+
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
1049 |
+
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
1050 |
+
out_spec = lhs_spec
|
1051 |
+
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
1052 |
+
|
1053 |
+
|
1054 |
+
class _Conv(nn.Module):
|
1055 |
+
"""Convolution Module wrapping `lax.conv_general_dilated[_local]`.
|
1056 |
+
|
1057 |
+
Attributes:
|
1058 |
+
features: number of convolution filters.
|
1059 |
+
kernel_size: shape of the convolutional kernel. For 1D convolution,
|
1060 |
+
the kernel size can be passed as an integer. For all other cases, it must
|
1061 |
+
be a sequence of integers.
|
1062 |
+
strides: an integer or a sequence of `n` integers, representing the
|
1063 |
+
inter-window strides (default: 1).
|
1064 |
+
padding: either the string `'SAME'`, the string `'VALID'`, the string
|
1065 |
+
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
|
1066 |
+
high)` integer pairs that give the padding to apply before and after each
|
1067 |
+
spatial dimension. A single int is interpeted as applying the same padding
|
1068 |
+
in all dims and passign a single int in a sequence causes the same padding
|
1069 |
+
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
|
1070 |
+
left-pad the convolution axis, resulting in same-sized output.
|
1071 |
+
input_dilation: an integer or a sequence of `n` integers, giving the
|
1072 |
+
dilation factor to apply in each spatial dimension of `inputs`
|
1073 |
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
1074 |
+
transposed convolution with stride `d`.
|
1075 |
+
kernel_dilation: an integer or a sequence of `n` integers, giving the
|
1076 |
+
dilation factor to apply in each spatial dimension of the convolution
|
1077 |
+
kernel (default: 1). Convolution with kernel dilation
|
1078 |
+
is also known as 'atrous convolution'.
|
1079 |
+
feature_group_count: integer, default 1. If specified divides the input
|
1080 |
+
features into groups.
|
1081 |
+
use_bias: whether to add a bias to the output (default: True).
|
1082 |
+
mask: Optional mask for the weights during masked convolution. The mask must
|
1083 |
+
be the same shape as the convolution weight matrix.
|
1084 |
+
dtype: the dtype of the computation (default: infer from input and params).
|
1085 |
+
params_dtype: the dtype passed to parameter initializers (default: float32).
|
1086 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
1087 |
+
for details.
|
1088 |
+
kernel_init: initializer for the convolutional kernel.
|
1089 |
+
bias_init: initializer for the bias.
|
1090 |
+
"""
|
1091 |
+
|
1092 |
+
features: int
|
1093 |
+
kernel_size: Sequence[int]
|
1094 |
+
strides: Union[None, int, Sequence[int]] = 1
|
1095 |
+
padding: PaddingLike = "SAME"
|
1096 |
+
input_dilation: Union[None, int, Sequence[int]] = 1
|
1097 |
+
kernel_dilation: Union[None, int, Sequence[int]] = 1
|
1098 |
+
feature_group_count: int = 1
|
1099 |
+
use_bias: bool = True
|
1100 |
+
mask: Optional[Array] = None
|
1101 |
+
dtype: Optional[DType] = None
|
1102 |
+
params_dtype: DType = jnp.float32
|
1103 |
+
precision: PrecisionLike = None
|
1104 |
+
kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
|
1105 |
+
bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
|
1106 |
+
conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
|
1107 |
+
kernel_axes: Tuple[str, ...] = ()
|
1108 |
+
|
1109 |
+
@property
|
1110 |
+
def shared_weights(self) -> bool: # type: ignore
|
1111 |
+
"""Defines whether weights are shared or not between different pixels.
|
1112 |
+
|
1113 |
+
Returns:
|
1114 |
+
`True` to use shared weights in convolution (regular convolution).
|
1115 |
+
`False` to use different weights at different pixels, a.k.a.
|
1116 |
+
"locally connected layer", "unshared convolution", or "local convolution".
|
1117 |
+
|
1118 |
+
"""
|
1119 |
+
...
|
1120 |
+
|
1121 |
+
@nn.compact
|
1122 |
+
def __call__(self, inputs: Array) -> Array:
|
1123 |
+
"""Applies a (potentially unshared) convolution to the inputs.
|
1124 |
+
|
1125 |
+
Args:
|
1126 |
+
inputs: input data with dimensions (*batch_dims, spatial_dims...,
|
1127 |
+
features). This is the channels-last convention, i.e. NHWC for a 2d
|
1128 |
+
convolution and NDHWC for a 3D convolution. Note: this is different from
|
1129 |
+
the input convention used by `lax.conv_general_dilated`, which puts the
|
1130 |
+
spatial dimensions last.
|
1131 |
+
Note: If the input has more than 1 batch dimension, all batch dimensions
|
1132 |
+
are flattened into a single dimension for the convolution and restored
|
1133 |
+
before returning. In some cases directly vmap'ing the layer may yield
|
1134 |
+
better performance than this default flattening approach. If the input
|
1135 |
+
lacks a batch dimension it will be added for the convolution and removed
|
1136 |
+
n return, an allowance made to enable writing single-example code.
|
1137 |
+
|
1138 |
+
Returns:
|
1139 |
+
The convolved data.
|
1140 |
+
"""
|
1141 |
+
|
1142 |
+
if isinstance(self.kernel_size, int):
|
1143 |
+
raise TypeError(
|
1144 |
+
"Expected Conv kernel_size to be a"
|
1145 |
+
" tuple/list of integers (eg.: [3, 3]) but got"
|
1146 |
+
f" {self.kernel_size}."
|
1147 |
+
)
|
1148 |
+
else:
|
1149 |
+
kernel_size = tuple(self.kernel_size)
|
1150 |
+
|
1151 |
+
def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
|
1152 |
+
if x is None:
|
1153 |
+
# backward compatibility with using None as sentinel for
|
1154 |
+
# broadcast 1
|
1155 |
+
x = 1
|
1156 |
+
if isinstance(x, int):
|
1157 |
+
return (x,) * len(kernel_size)
|
1158 |
+
return tuple(x)
|
1159 |
+
|
1160 |
+
# Combine all input batch dimensions into a single leading batch axis.
|
1161 |
+
num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
|
1162 |
+
if num_batch_dimensions != 1:
|
1163 |
+
input_batch_shape = inputs.shape[:num_batch_dimensions]
|
1164 |
+
total_batch_size = int(np.prod(input_batch_shape))
|
1165 |
+
flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
|
1166 |
+
inputs = jnp.reshape(inputs, flat_input_shape)
|
1167 |
+
|
1168 |
+
# self.strides or (1,) * (inputs.ndim - 2)
|
1169 |
+
strides = maybe_broadcast(self.strides)
|
1170 |
+
input_dilation = maybe_broadcast(self.input_dilation)
|
1171 |
+
kernel_dilation = maybe_broadcast(self.kernel_dilation)
|
1172 |
+
|
1173 |
+
padding_lax = canonicalize_padding(self.padding, len(kernel_size))
|
1174 |
+
if padding_lax == "CIRCULAR":
|
1175 |
+
kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
|
1176 |
+
zero_pad: List[Tuple[int, int]] = [(0, 0)]
|
1177 |
+
pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
|
1178 |
+
inputs = jnp.pad(inputs, pads, mode="wrap")
|
1179 |
+
padding_lax = "VALID"
|
1180 |
+
elif padding_lax == "CAUSAL":
|
1181 |
+
if len(kernel_size) != 1:
|
1182 |
+
raise ValueError("Causal padding is only implemented for 1D convolutions.")
|
1183 |
+
left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
|
1184 |
+
pads = [(0, 0), (left_pad, 0), (0, 0)]
|
1185 |
+
inputs = jnp.pad(inputs, pads)
|
1186 |
+
padding_lax = "VALID"
|
1187 |
+
|
1188 |
+
dimension_numbers = _conv_dimension_numbers(inputs.shape)
|
1189 |
+
in_features = jnp.shape(inputs)[-1]
|
1190 |
+
|
1191 |
+
if self.shared_weights:
|
1192 |
+
# One shared convolutional kernel for all pixels in the output.
|
1193 |
+
assert in_features % self.feature_group_count == 0
|
1194 |
+
kernel_shape = kernel_size + (
|
1195 |
+
in_features // self.feature_group_count,
|
1196 |
+
self.features,
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
else:
|
1200 |
+
if self.feature_group_count != 1:
|
1201 |
+
raise NotImplementedError(
|
1202 |
+
"`lax.conv_general_dilated_local` does not support "
|
1203 |
+
f"`feature_group_count != 1`, got `{self.feature_group_count}`."
|
1204 |
+
)
|
1205 |
+
|
1206 |
+
# Need to know the spatial output shape of a standard convolution to
|
1207 |
+
# create the unshared convolution kernel.
|
1208 |
+
conv_output_shape = jax.eval_shape(
|
1209 |
+
lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda
|
1210 |
+
lhs=lhs,
|
1211 |
+
rhs=rhs,
|
1212 |
+
window_strides=strides,
|
1213 |
+
padding=padding_lax,
|
1214 |
+
dimension_numbers=dimension_numbers,
|
1215 |
+
lhs_dilation=input_dilation,
|
1216 |
+
rhs_dilation=kernel_dilation,
|
1217 |
+
),
|
1218 |
+
inputs,
|
1219 |
+
jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
|
1220 |
+
).shape
|
1221 |
+
|
1222 |
+
# One (unshared) convolutional kernel per each pixel in the output.
|
1223 |
+
kernel_shape = conv_output_shape[1:-1] + (
|
1224 |
+
np.prod(kernel_size) * in_features,
|
1225 |
+
self.features,
|
1226 |
+
)
|
1227 |
+
|
1228 |
+
if self.mask is not None and self.mask.shape != kernel_shape:
|
1229 |
+
raise ValueError(
|
1230 |
+
"Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
kernel = param_with_axes(
|
1234 |
+
"kernel",
|
1235 |
+
self.kernel_init,
|
1236 |
+
kernel_shape,
|
1237 |
+
self.params_dtype,
|
1238 |
+
axes=self.kernel_axes,
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
if self.mask is not None:
|
1242 |
+
kernel *= self.mask
|
1243 |
+
|
1244 |
+
if self.use_bias:
|
1245 |
+
if self.shared_weights:
|
1246 |
+
# One bias weight per output channel, shared between pixels.
|
1247 |
+
bias_shape = (self.features,)
|
1248 |
+
else:
|
1249 |
+
# One bias weight per output entry, unshared betwen pixels.
|
1250 |
+
bias_shape = conv_output_shape[1:]
|
1251 |
+
|
1252 |
+
bias = param_with_axes(
|
1253 |
+
"bias",
|
1254 |
+
self.bias_init,
|
1255 |
+
bias_shape,
|
1256 |
+
self.params_dtype,
|
1257 |
+
axes=(self.kernel_axes[-1],),
|
1258 |
+
)
|
1259 |
+
else:
|
1260 |
+
bias = None
|
1261 |
+
|
1262 |
+
inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
|
1263 |
+
if self.shared_weights:
|
1264 |
+
y = self.conv_general_dilated(
|
1265 |
+
inputs,
|
1266 |
+
kernel,
|
1267 |
+
strides,
|
1268 |
+
padding_lax,
|
1269 |
+
lhs_dilation=input_dilation,
|
1270 |
+
rhs_dilation=kernel_dilation,
|
1271 |
+
dimension_numbers=dimension_numbers,
|
1272 |
+
feature_group_count=self.feature_group_count,
|
1273 |
+
precision=self.precision,
|
1274 |
+
)
|
1275 |
+
else:
|
1276 |
+
y = lax.conv_general_dilated_local(
|
1277 |
+
lhs=inputs,
|
1278 |
+
rhs=kernel,
|
1279 |
+
window_strides=strides,
|
1280 |
+
padding=padding_lax,
|
1281 |
+
filter_shape=kernel_size,
|
1282 |
+
lhs_dilation=input_dilation,
|
1283 |
+
rhs_dilation=kernel_dilation,
|
1284 |
+
dimension_numbers=dimension_numbers,
|
1285 |
+
precision=self.precision,
|
1286 |
+
)
|
1287 |
+
|
1288 |
+
if self.use_bias:
|
1289 |
+
bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
|
1290 |
+
y += bias
|
1291 |
+
|
1292 |
+
if num_batch_dimensions != 1:
|
1293 |
+
output_shape = input_batch_shape + y.shape[1:]
|
1294 |
+
y = jnp.reshape(y, output_shape)
|
1295 |
+
return y
|
1296 |
+
|
1297 |
+
|
1298 |
+
class Conv(_Conv):
|
1299 |
+
"""Convolution Module wrapping `lax.conv_general_dilated`.
|
1300 |
+
|
1301 |
+
Attributes:
|
1302 |
+
features: number of convolution filters.
|
1303 |
+
kernel_size: shape of the convolutional kernel. For 1D convolution,
|
1304 |
+
the kernel size can be passed as an integer. For all other cases, it must
|
1305 |
+
be a sequence of integers.
|
1306 |
+
strides: an integer or a sequence of `n` integers, representing the
|
1307 |
+
inter-window strides (default: 1).
|
1308 |
+
padding: either the string `'SAME'`, the string `'VALID'`, the string
|
1309 |
+
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
|
1310 |
+
high)` integer pairs that give the padding to apply before and after each
|
1311 |
+
spatial dimension. A single int is interpeted as applying the same padding
|
1312 |
+
in all dims and passign a single int in a sequence causes the same padding
|
1313 |
+
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
|
1314 |
+
left-pad the convolution axis, resulting in same-sized output.
|
1315 |
+
input_dilation: an integer or a sequence of `n` integers, giving the
|
1316 |
+
dilation factor to apply in each spatial dimension of `inputs`
|
1317 |
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
1318 |
+
transposed convolution with stride `d`.
|
1319 |
+
kernel_dilation: an integer or a sequence of `n` integers, giving the
|
1320 |
+
dilation factor to apply in each spatial dimension of the convolution
|
1321 |
+
kernel (default: 1). Convolution with kernel dilation
|
1322 |
+
is also known as 'atrous convolution'.
|
1323 |
+
feature_group_count: integer, default 1. If specified divides the input
|
1324 |
+
features into groups.
|
1325 |
+
use_bias: whether to add a bias to the output (default: True).
|
1326 |
+
mask: Optional mask for the weights during masked convolution. The mask must
|
1327 |
+
be the same shape as the convolution weight matrix.
|
1328 |
+
dtype: the dtype of the computation (default: infer from input and params).
|
1329 |
+
params_dtype: the dtype passed to parameter initializers (default: float32).
|
1330 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
1331 |
+
for details.
|
1332 |
+
kernel_init: initializer for the convolutional kernel.
|
1333 |
+
bias_init: initializer for the bias.
|
1334 |
+
"""
|
1335 |
+
|
1336 |
+
@property
|
1337 |
+
def shared_weights(self) -> bool:
|
1338 |
+
return True
|
flax/distil_whisper/modeling_flax_whisper.py
ADDED
@@ -0,0 +1,2135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax whisper model."""
|
16 |
+
|
17 |
+
import random
|
18 |
+
from functools import partial
|
19 |
+
from typing import Dict, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import flax.linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
25 |
+
from flax.linen import combine_masks, make_causal_mask
|
26 |
+
from flax.linen.attention import dot_product_attention_weights
|
27 |
+
from flax.linen.partitioning import remat, scan_with_axes
|
28 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
29 |
+
from jax import lax
|
30 |
+
from jax.random import PRNGKey
|
31 |
+
from transformers import WhisperConfig
|
32 |
+
from transformers.generation.flax_logits_process import (
|
33 |
+
FlaxLogitsProcessor,
|
34 |
+
FlaxLogitsProcessorList,
|
35 |
+
FlaxWhisperTimeStampLogitsProcessor,
|
36 |
+
)
|
37 |
+
from transformers.modeling_flax_outputs import (
|
38 |
+
FlaxBaseModelOutput,
|
39 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
40 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
41 |
+
FlaxSeq2SeqLMOutput,
|
42 |
+
FlaxSeq2SeqModelOutput,
|
43 |
+
)
|
44 |
+
from transformers.modeling_flax_utils import (
|
45 |
+
ACT2FN,
|
46 |
+
FlaxPreTrainedModel,
|
47 |
+
append_call_sample_docstring,
|
48 |
+
append_replace_return_docstrings,
|
49 |
+
overwrite_call_docstring,
|
50 |
+
)
|
51 |
+
from transformers.utils import (
|
52 |
+
add_start_docstrings,
|
53 |
+
add_start_docstrings_to_model_forward,
|
54 |
+
logging,
|
55 |
+
replace_return_docstrings,
|
56 |
+
)
|
57 |
+
|
58 |
+
from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint
|
59 |
+
|
60 |
+
|
61 |
+
logger = logging.get_logger(__name__)
|
62 |
+
|
63 |
+
|
64 |
+
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
65 |
+
_CONFIG_FOR_DOC = "WhisperConfig"
|
66 |
+
|
67 |
+
|
68 |
+
WHISPER_START_DOCSTRING = r"""
|
69 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
70 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
71 |
+
etc.) This model is also a Flax Linen
|
72 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
73 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
74 |
+
Finally, this model supports inherent JAX features such as:
|
75 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
76 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
77 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
78 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
|
82 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
83 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
84 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
85 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
86 |
+
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
|
87 |
+
inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
|
88 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
89 |
+
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
|
90 |
+
and [`~FlaxPreTrainedModel.to_bf16`].
|
91 |
+
"""
|
92 |
+
|
93 |
+
WHISPER_INPUTS_DOCSTRING = r"""
|
94 |
+
Args:
|
95 |
+
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
|
96 |
+
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
|
97 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
98 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
99 |
+
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
|
100 |
+
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
|
101 |
+
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
102 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
|
103 |
+
is not used. By default the silence in the input log mel spectrogram are ignored.
|
104 |
+
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
105 |
+
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
|
106 |
+
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
107 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
|
108 |
+
the starting token for `decoder_input_ids` generation.
|
109 |
+
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
110 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
111 |
+
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
|
112 |
+
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
113 |
+
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
114 |
+
Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
|
115 |
+
use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
|
116 |
+
spectrogram are ignored.
|
117 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
118 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
119 |
+
range `[0, config.max_position_embeddings - 1]`.
|
120 |
+
output_attentions (`bool`, *optional*):
|
121 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
122 |
+
tensors for more detail.
|
123 |
+
output_hidden_states (`bool`, *optional*):
|
124 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
125 |
+
more detail.
|
126 |
+
return_dict (`bool`, *optional*):
|
127 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
128 |
+
"""
|
129 |
+
|
130 |
+
WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
|
131 |
+
Args:
|
132 |
+
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
|
133 |
+
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
|
134 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
135 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
136 |
+
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
|
137 |
+
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
|
138 |
+
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
139 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
|
140 |
+
is not used. By default the silence in the input log mel spectrogram are ignored.
|
141 |
+
output_attentions (`bool`, *optional*):
|
142 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
143 |
+
tensors for more detail.
|
144 |
+
output_hidden_states (`bool`, *optional*):
|
145 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
146 |
+
more detail.
|
147 |
+
return_dict (`bool`, *optional*):
|
148 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
149 |
+
"""
|
150 |
+
|
151 |
+
WHISPER_DECODE_INPUTS_DOCSTRING = r"""
|
152 |
+
Args:
|
153 |
+
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
|
154 |
+
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
|
155 |
+
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
156 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
157 |
+
encoder_outputs (`tuple(tuple(numpy.ndarray)`):
|
158 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
159 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
160 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
161 |
+
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
162 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
|
163 |
+
but it is not used. By default the silence in the input log mel spectrogram are ignored.
|
164 |
+
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
165 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
166 |
+
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
|
167 |
+
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
168 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
169 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
170 |
+
range `[0, config.max_position_embeddings - 1]`.
|
171 |
+
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
172 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
173 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
174 |
+
output_attentions (`bool`, *optional*):
|
175 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
176 |
+
tensors for more detail.
|
177 |
+
output_hidden_states (`bool`, *optional*):
|
178 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
179 |
+
more detail.
|
180 |
+
return_dict (`bool`, *optional*):
|
181 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
182 |
+
"""
|
183 |
+
|
184 |
+
|
185 |
+
class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor):
|
186 |
+
r"""
|
187 |
+
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
|
188 |
+
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
|
189 |
+
to `-inf` so that they are sampled at their corresponding index. This is a static version of the `transformers` logit
|
190 |
+
processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
force_token_map (`list`):
|
194 |
+
Map giving token ids and indices where they will be forced to be sampled.
|
195 |
+
"""
|
196 |
+
|
197 |
+
def __init__(self, force_token_map):
|
198 |
+
# The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
|
199 |
+
# JAX type, and so we switch to using a JAX array instead
|
200 |
+
force_token_map = jnp.array(force_token_map)
|
201 |
+
# Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
|
202 |
+
# index of the array corresponds to the index of the token to be forced. For XLA compatibility,
|
203 |
+
# indexes without forced tokens will have a negative value. Note that the last token we ever need to force in
|
204 |
+
# Whisper is at position 3, so we only construct an array up to this index. The native version constructs a tensor
|
205 |
+
# dynamically according to the length of the `force_token_map`. Array shapes need to be concrete for XLA compatibility,
|
206 |
+
# so this is not permitted here.
|
207 |
+
force_token_array = jnp.ones(3, dtype=jnp.int32) * -1
|
208 |
+
for index, token in force_token_map:
|
209 |
+
force_token_array = force_token_array.at[index].set(token)
|
210 |
+
self.force_token_array = jnp.int32(force_token_array)
|
211 |
+
|
212 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
213 |
+
def _force_token(generation_idx):
|
214 |
+
batch_size = scores.shape[0]
|
215 |
+
current_token = self.force_token_array[generation_idx]
|
216 |
+
|
217 |
+
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
|
218 |
+
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
|
219 |
+
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
|
220 |
+
return new_scores
|
221 |
+
|
222 |
+
scores = lax.cond(
|
223 |
+
cur_len >= self.force_token_array.shape[0],
|
224 |
+
# If the current length is geq than the length of force_token_array, the processor does nothing.
|
225 |
+
lambda: scores,
|
226 |
+
# Otherwise, it may force a certain token.
|
227 |
+
lambda: lax.cond(
|
228 |
+
self.force_token_array[cur_len] >= 0,
|
229 |
+
# Only valid (positive) tokens are forced
|
230 |
+
lambda: _force_token(cur_len),
|
231 |
+
# Otherwise, the processor does nothing.
|
232 |
+
lambda: scores,
|
233 |
+
),
|
234 |
+
)
|
235 |
+
return scores
|
236 |
+
|
237 |
+
|
238 |
+
class FlaxWhisperAttention(nn.Module):
|
239 |
+
config: WhisperConfig
|
240 |
+
embed_dim: int
|
241 |
+
num_heads: int
|
242 |
+
dropout: float = 0.0
|
243 |
+
causal: bool = False
|
244 |
+
bias: bool = True
|
245 |
+
dtype: jnp.dtype = jnp.float32
|
246 |
+
params_dtype: jnp.dtype = jnp.float32
|
247 |
+
|
248 |
+
def setup(self) -> None:
|
249 |
+
self.head_dim = self.embed_dim // self.num_heads
|
250 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
251 |
+
raise ValueError(
|
252 |
+
"embed_dim must be divisible by num_heads (got `embed_dim`:"
|
253 |
+
f" {self.embed_dim} and `num_heads`: {self.num_heads})."
|
254 |
+
)
|
255 |
+
|
256 |
+
dense = partial(
|
257 |
+
DenseGeneral,
|
258 |
+
self.embed_dim,
|
259 |
+
axis=-1,
|
260 |
+
dtype=self.dtype,
|
261 |
+
params_dtype=self.params_dtype,
|
262 |
+
kernel_axes=("embed", "joined_kv"),
|
263 |
+
)
|
264 |
+
|
265 |
+
self.q_proj = dense(use_bias=self.bias)
|
266 |
+
self.k_proj = dense(use_bias=False)
|
267 |
+
self.v_proj = dense(use_bias=self.bias)
|
268 |
+
|
269 |
+
self.out_proj = DenseGeneral(
|
270 |
+
self.embed_dim,
|
271 |
+
axis=-1,
|
272 |
+
dtype=self.dtype,
|
273 |
+
params_dtype=self.params_dtype,
|
274 |
+
kernel_axes=("joined_kv", "embed"),
|
275 |
+
use_bias=self.bias,
|
276 |
+
)
|
277 |
+
|
278 |
+
if self.causal:
|
279 |
+
self.causal_mask = make_causal_mask(
|
280 |
+
jnp.ones((1, self.config.max_target_positions), dtype="bool"),
|
281 |
+
dtype="bool",
|
282 |
+
)
|
283 |
+
|
284 |
+
def __call__(
|
285 |
+
self,
|
286 |
+
hidden_states: jnp.ndarray,
|
287 |
+
key_value_states: Optional[jnp.ndarray] = None,
|
288 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
289 |
+
init_cache: bool = False,
|
290 |
+
deterministic: bool = True,
|
291 |
+
) -> Tuple[jnp.ndarray]:
|
292 |
+
is_cross_attention = key_value_states is not None
|
293 |
+
batch_size = hidden_states.shape[0]
|
294 |
+
|
295 |
+
query_states = self.q_proj(hidden_states)
|
296 |
+
|
297 |
+
if is_cross_attention:
|
298 |
+
key_states = self.k_proj(key_value_states)
|
299 |
+
value_states = self.v_proj(key_value_states)
|
300 |
+
else:
|
301 |
+
key_states = self.k_proj(hidden_states)
|
302 |
+
value_states = self.v_proj(hidden_states)
|
303 |
+
|
304 |
+
query_states = self._split_heads(query_states)
|
305 |
+
key_states = self._split_heads(key_states)
|
306 |
+
value_states = self._split_heads(value_states)
|
307 |
+
|
308 |
+
query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv"))
|
309 |
+
key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv"))
|
310 |
+
value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv"))
|
311 |
+
|
312 |
+
if self.causal:
|
313 |
+
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
314 |
+
if self.has_variable("cache", "cached_key"):
|
315 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
316 |
+
# max_length of cached_key is last dim
|
317 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[-1]
|
318 |
+
causal_mask = lax.dynamic_slice(
|
319 |
+
self.causal_mask,
|
320 |
+
(0, 0, mask_shift, 0),
|
321 |
+
(1, 1, query_length, max_decoder_length),
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
325 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
326 |
+
|
327 |
+
# combine masks if needed
|
328 |
+
if attention_mask is not None and self.causal:
|
329 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
330 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
331 |
+
elif self.causal:
|
332 |
+
attention_mask = causal_mask
|
333 |
+
elif attention_mask is not None:
|
334 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
335 |
+
|
336 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
337 |
+
# and cache the keys and values step by step.
|
338 |
+
|
339 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
340 |
+
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
341 |
+
key_states, value_states, query_states, attention_mask
|
342 |
+
)
|
343 |
+
|
344 |
+
# Convert the boolean attention mask to an attention bias.
|
345 |
+
if attention_mask is not None:
|
346 |
+
# attention mask in the form of attention bias
|
347 |
+
attention_bias = lax.select(
|
348 |
+
attention_mask > 0,
|
349 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
350 |
+
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
351 |
+
)
|
352 |
+
else:
|
353 |
+
attention_bias = None
|
354 |
+
|
355 |
+
dropout_rng = None
|
356 |
+
if not deterministic and self.dropout > 0.0:
|
357 |
+
dropout_rng = self.make_rng("dropout")
|
358 |
+
|
359 |
+
attn_weights = dot_product_attention_weights(
|
360 |
+
query_states,
|
361 |
+
key_states,
|
362 |
+
bias=attention_bias,
|
363 |
+
dropout_rng=dropout_rng,
|
364 |
+
dropout_rate=self.dropout,
|
365 |
+
broadcast_dropout=True,
|
366 |
+
deterministic=deterministic,
|
367 |
+
dtype=self.dtype,
|
368 |
+
precision=None,
|
369 |
+
)
|
370 |
+
|
371 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
372 |
+
attn_output = self._merge_heads(attn_output)
|
373 |
+
attn_output = self.out_proj(attn_output)
|
374 |
+
|
375 |
+
return attn_output, attn_weights
|
376 |
+
|
377 |
+
def _split_heads(self, hidden_state) -> jnp.ndarray:
|
378 |
+
return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
|
379 |
+
|
380 |
+
def _merge_heads(self, hidden_state) -> jnp.ndarray:
|
381 |
+
return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
|
382 |
+
|
383 |
+
@nn.compact
|
384 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
385 |
+
# The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
|
386 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
387 |
+
|
388 |
+
# The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
|
389 |
+
# but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
|
390 |
+
# fusion optimization. This also enables the "scatter via one-hot
|
391 |
+
# broadcast" trick, which means we do a one-hot broadcast instead of a
|
392 |
+
# scatter/gather operations, resulting in a 3-4x speedup in practice.
|
393 |
+
def swap_dims(x):
|
394 |
+
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
|
395 |
+
|
396 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
|
397 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
|
398 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
399 |
+
|
400 |
+
if is_initialized:
|
401 |
+
batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
|
402 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
403 |
+
# and cache the keys and values step by step.
|
404 |
+
# Sanity shape check of cached key against input query.
|
405 |
+
num_updated_cache_vectors = query.shape[1]
|
406 |
+
expected_shape = (batch_size, 1, num_heads, head_dim)
|
407 |
+
if num_updated_cache_vectors == 1 and expected_shape != query.shape:
|
408 |
+
raise ValueError(
|
409 |
+
"Autoregressive cache shape error, expected query shape"
|
410 |
+
f" {expected_shape} instead got {query.shape}"
|
411 |
+
)
|
412 |
+
|
413 |
+
# Create a OHE of the current index. NOTE: the index is increased below.
|
414 |
+
cur_index = cache_index.value
|
415 |
+
|
416 |
+
# In order to update the key, value caches with the current key and
|
417 |
+
# value, we move the seq_length axis to the back, similar to what we did for
|
418 |
+
# the cached ones above.
|
419 |
+
# Note these are currently the key and value of a single position, since
|
420 |
+
# we feed one position at a time.
|
421 |
+
one_token_key = jnp.moveaxis(key, -3, -1)
|
422 |
+
one_token_value = jnp.moveaxis(value, -3, -1)
|
423 |
+
|
424 |
+
# Update key, value caches with our new 1d spatial slices.
|
425 |
+
# We implement an efficient scatter into the cache via one-hot
|
426 |
+
# broadcast and addition.
|
427 |
+
if num_updated_cache_vectors > 1:
|
428 |
+
indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None]
|
429 |
+
key = cached_key.value + jnp.matmul(one_token_key, indices)
|
430 |
+
value = cached_value.value + jnp.matmul(one_token_value, indices)
|
431 |
+
else:
|
432 |
+
one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
|
433 |
+
key = cached_key.value + one_token_key * one_hot_indices
|
434 |
+
value = cached_value.value + one_token_value * one_hot_indices
|
435 |
+
|
436 |
+
cached_key.value = key
|
437 |
+
cached_value.value = value
|
438 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
439 |
+
|
440 |
+
# Move the keys and values back to their original shapes.
|
441 |
+
key = jnp.moveaxis(key, -1, -3)
|
442 |
+
value = jnp.moveaxis(value, -1, -3)
|
443 |
+
|
444 |
+
# causal mask for cached decoder self-attention: our single query position should only
|
445 |
+
# attend to those key positions that have already been generated and cached, not the
|
446 |
+
# remaining zero elements.
|
447 |
+
pad_mask = jnp.broadcast_to(
|
448 |
+
jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
|
449 |
+
(batch_size,) + (1, num_updated_cache_vectors, seq_length),
|
450 |
+
)
|
451 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
452 |
+
|
453 |
+
return key, value, attention_mask
|
454 |
+
|
455 |
+
|
456 |
+
class FlaxWhisperEncoderLayer(nn.Module):
|
457 |
+
config: WhisperConfig
|
458 |
+
dtype: jnp.dtype = jnp.float32
|
459 |
+
params_dtype: jnp.dtype = jnp.float32
|
460 |
+
use_scan: bool = False
|
461 |
+
|
462 |
+
def setup(self) -> None:
|
463 |
+
self.embed_dim = self.config.d_model
|
464 |
+
self.self_attn = FlaxWhisperAttention(
|
465 |
+
config=self.config,
|
466 |
+
embed_dim=self.embed_dim,
|
467 |
+
num_heads=self.config.encoder_attention_heads,
|
468 |
+
dropout=self.config.attention_dropout,
|
469 |
+
dtype=self.dtype,
|
470 |
+
params_dtype=self.params_dtype,
|
471 |
+
)
|
472 |
+
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
473 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
474 |
+
self.activation_fn = ACT2FN[self.config.activation_function]
|
475 |
+
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
476 |
+
self.fc1 = DenseGeneral(
|
477 |
+
self.config.encoder_ffn_dim,
|
478 |
+
dtype=self.dtype,
|
479 |
+
params_dtype=self.params_dtype,
|
480 |
+
kernel_axes=("embed", "mlp"),
|
481 |
+
)
|
482 |
+
self.fc2 = DenseGeneral(
|
483 |
+
self.embed_dim,
|
484 |
+
dtype=self.dtype,
|
485 |
+
params_dtype=self.params_dtype,
|
486 |
+
kernel_axes=("mlp", "embed"),
|
487 |
+
)
|
488 |
+
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
489 |
+
|
490 |
+
def __call__(
|
491 |
+
self,
|
492 |
+
hidden_states: jnp.ndarray,
|
493 |
+
attention_mask: jnp.ndarray,
|
494 |
+
output_attentions: bool = True,
|
495 |
+
deterministic: bool = True,
|
496 |
+
all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
|
497 |
+
) -> Tuple[jnp.ndarray]:
|
498 |
+
if self.use_scan:
|
499 |
+
hidden_states = hidden_states[0]
|
500 |
+
|
501 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
502 |
+
|
503 |
+
residual = hidden_states
|
504 |
+
|
505 |
+
layernorm_output = self.self_attn_layer_norm(hidden_states)
|
506 |
+
layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed"))
|
507 |
+
|
508 |
+
attn_output, attn_weights = self.self_attn(hidden_states=layernorm_output, attention_mask=attention_mask)
|
509 |
+
attn_output = self.dropout_layer(attn_output, deterministic=deterministic)
|
510 |
+
attn_output = residual + attn_output
|
511 |
+
attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed"))
|
512 |
+
|
513 |
+
residual = attn_output
|
514 |
+
|
515 |
+
post_layer_norm = self.final_layer_norm(attn_output)
|
516 |
+
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
|
517 |
+
|
518 |
+
fc1_output = self.activation_fn(self.fc1(post_layer_norm))
|
519 |
+
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
|
520 |
+
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
|
521 |
+
|
522 |
+
hidden_states = self.fc2(fc1_output)
|
523 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
524 |
+
hidden_states = residual + hidden_states
|
525 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
526 |
+
|
527 |
+
outputs = (hidden_states,)
|
528 |
+
|
529 |
+
if output_attentions:
|
530 |
+
outputs += (attn_weights,)
|
531 |
+
|
532 |
+
if self.use_scan:
|
533 |
+
if all_hidden_states is not None:
|
534 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
535 |
+
outputs = (
|
536 |
+
outputs,
|
537 |
+
all_hidden_states,
|
538 |
+
)
|
539 |
+
|
540 |
+
return outputs
|
541 |
+
|
542 |
+
|
543 |
+
class FlaxWhisperEncoderLayerCollection(nn.Module):
|
544 |
+
config: WhisperConfig
|
545 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
546 |
+
params_dtype: jnp.dtype = jnp.float32
|
547 |
+
use_scan: bool = False
|
548 |
+
gradient_checkpointing: bool = False
|
549 |
+
|
550 |
+
@nn.compact
|
551 |
+
def __call__(
|
552 |
+
self,
|
553 |
+
hidden_states,
|
554 |
+
attention_mask,
|
555 |
+
deterministic: bool = True,
|
556 |
+
output_attentions: bool = False,
|
557 |
+
output_hidden_states: bool = False,
|
558 |
+
return_dict: bool = True,
|
559 |
+
):
|
560 |
+
all_attentions = () if output_attentions else None
|
561 |
+
all_hidden_states = () if output_hidden_states else None
|
562 |
+
|
563 |
+
FlaxWhisperEncoderCheckpointLayer = (
|
564 |
+
remat(
|
565 |
+
FlaxWhisperEncoderLayer,
|
566 |
+
static_argnums=(2, 3),
|
567 |
+
prevent_cse=not self.use_scan,
|
568 |
+
)
|
569 |
+
if self.gradient_checkpointing
|
570 |
+
else FlaxWhisperEncoderLayer
|
571 |
+
)
|
572 |
+
|
573 |
+
if self.use_scan:
|
574 |
+
if output_attentions:
|
575 |
+
raise ValueError("Cannot use `scan` with `output_attentions` set to True")
|
576 |
+
|
577 |
+
# nicest behaviour for scan is to let the compiler figure out the correct shapes for the hidden states
|
578 |
+
# so we'll just pass an empty tuple as the carry initializer and hold on to the first hidden states for later
|
579 |
+
input_hidden_states = hidden_states
|
580 |
+
hidden_states = (hidden_states,)
|
581 |
+
|
582 |
+
hidden_states, all_hidden_states = scan_with_axes(
|
583 |
+
FlaxWhisperEncoderCheckpointLayer,
|
584 |
+
variable_axes={"params": 0, "cache": 0},
|
585 |
+
split_rngs={"params": True, "dropout": True},
|
586 |
+
in_axes=(
|
587 |
+
nn.broadcast,
|
588 |
+
nn.broadcast,
|
589 |
+
nn.broadcast,
|
590 |
+
nn.broadcast,
|
591 |
+
),
|
592 |
+
variable_carry="all_hidden_states",
|
593 |
+
length=self.config.encoder_layers,
|
594 |
+
)(
|
595 |
+
self.config,
|
596 |
+
dtype=self.dtype,
|
597 |
+
params_dtype=self.params_dtype,
|
598 |
+
use_scan=True,
|
599 |
+
name="FlaxEncoderScanLayers",
|
600 |
+
)(
|
601 |
+
hidden_states,
|
602 |
+
attention_mask,
|
603 |
+
output_attentions,
|
604 |
+
deterministic,
|
605 |
+
all_hidden_states, # tuple intializer (or None if not using output_hidden_states)
|
606 |
+
)
|
607 |
+
|
608 |
+
# remove the scan dimension
|
609 |
+
hidden_states = hidden_states[0]
|
610 |
+
|
611 |
+
if output_hidden_states:
|
612 |
+
# if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
|
613 |
+
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
|
614 |
+
|
615 |
+
else:
|
616 |
+
for layer_idx in range(self.config.encoder_layers):
|
617 |
+
if output_hidden_states:
|
618 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
619 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
620 |
+
dropout_probability = random.uniform(0, 1)
|
621 |
+
if not deterministic and (dropout_probability < self.config.encoder_layerdrop): # skip the layer
|
622 |
+
layer_outputs = (None, None)
|
623 |
+
else:
|
624 |
+
layer_outputs = FlaxWhisperEncoderCheckpointLayer(
|
625 |
+
self.config,
|
626 |
+
dtype=self.dtype,
|
627 |
+
params_dtype=self.params_dtype,
|
628 |
+
name=str(layer_idx),
|
629 |
+
)(
|
630 |
+
hidden_states,
|
631 |
+
attention_mask,
|
632 |
+
output_attentions,
|
633 |
+
deterministic,
|
634 |
+
)
|
635 |
+
hidden_states = layer_outputs[0]
|
636 |
+
if output_attentions:
|
637 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
638 |
+
|
639 |
+
if output_hidden_states:
|
640 |
+
all_hidden_states += (hidden_states,)
|
641 |
+
|
642 |
+
outputs = (hidden_states, all_hidden_states, all_attentions)
|
643 |
+
|
644 |
+
if not return_dict:
|
645 |
+
return tuple(v for v in outputs if v is not None)
|
646 |
+
|
647 |
+
return FlaxBaseModelOutput(
|
648 |
+
last_hidden_state=hidden_states,
|
649 |
+
hidden_states=all_hidden_states,
|
650 |
+
attentions=all_attentions,
|
651 |
+
)
|
652 |
+
|
653 |
+
|
654 |
+
class FlaxWhisperDecoderLayer(nn.Module):
|
655 |
+
config: WhisperConfig
|
656 |
+
dtype: jnp.dtype = jnp.float32
|
657 |
+
params_dtype: jnp.dtype = jnp.float32
|
658 |
+
use_scan: bool = False
|
659 |
+
|
660 |
+
def setup(self) -> None:
|
661 |
+
self.embed_dim = self.config.d_model
|
662 |
+
self.self_attn = FlaxWhisperAttention(
|
663 |
+
config=self.config,
|
664 |
+
embed_dim=self.embed_dim,
|
665 |
+
num_heads=self.config.decoder_attention_heads,
|
666 |
+
dropout=self.config.attention_dropout,
|
667 |
+
causal=True,
|
668 |
+
dtype=self.dtype,
|
669 |
+
params_dtype=self.params_dtype,
|
670 |
+
)
|
671 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
672 |
+
self.activation_fn = ACT2FN[self.config.activation_function]
|
673 |
+
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
674 |
+
|
675 |
+
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
676 |
+
self.encoder_attn = FlaxWhisperAttention(
|
677 |
+
config=self.config,
|
678 |
+
embed_dim=self.embed_dim,
|
679 |
+
num_heads=self.config.decoder_attention_heads,
|
680 |
+
dropout=self.config.attention_dropout,
|
681 |
+
dtype=self.dtype,
|
682 |
+
params_dtype=self.params_dtype,
|
683 |
+
)
|
684 |
+
self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
685 |
+
self.fc1 = DenseGeneral(
|
686 |
+
self.config.decoder_ffn_dim,
|
687 |
+
dtype=self.dtype,
|
688 |
+
params_dtype=self.params_dtype,
|
689 |
+
kernel_axes=("embed", "mlp"),
|
690 |
+
)
|
691 |
+
self.fc2 = DenseGeneral(
|
692 |
+
self.embed_dim,
|
693 |
+
dtype=self.dtype,
|
694 |
+
params_dtype=self.params_dtype,
|
695 |
+
kernel_axes=("mlp", "embed"),
|
696 |
+
)
|
697 |
+
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
698 |
+
|
699 |
+
def __call__(
|
700 |
+
self,
|
701 |
+
hidden_states: jnp.ndarray,
|
702 |
+
attention_mask: jnp.ndarray,
|
703 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
704 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
705 |
+
init_cache: bool = False,
|
706 |
+
output_attentions: bool = True,
|
707 |
+
deterministic: bool = True,
|
708 |
+
all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
|
709 |
+
) -> Tuple[jnp.ndarray]:
|
710 |
+
if self.use_scan:
|
711 |
+
hidden_states = hidden_states[0]
|
712 |
+
|
713 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
714 |
+
|
715 |
+
residual = hidden_states
|
716 |
+
|
717 |
+
layer_norm_output = self.self_attn_layer_norm(hidden_states)
|
718 |
+
layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed"))
|
719 |
+
|
720 |
+
# Self Attention
|
721 |
+
self_attn_output, self_attn_weights = self.self_attn(
|
722 |
+
hidden_states=layer_norm_output,
|
723 |
+
attention_mask=attention_mask,
|
724 |
+
init_cache=init_cache,
|
725 |
+
)
|
726 |
+
self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic)
|
727 |
+
self_attn_output = residual + self_attn_output
|
728 |
+
self_attn_output = with_sharding_constraint(self_attn_output, ("batch", "length", "embed"))
|
729 |
+
|
730 |
+
# Cross-Attention Block
|
731 |
+
cross_attn_weights = None
|
732 |
+
if encoder_hidden_states is not None:
|
733 |
+
residual = self_attn_output
|
734 |
+
|
735 |
+
encoder_layer_norm_output = self.encoder_attn_layer_norm(self_attn_output)
|
736 |
+
encoder_layer_norm_output = with_sharding_constraint(
|
737 |
+
encoder_layer_norm_output, ("batch", "length", "embed")
|
738 |
+
)
|
739 |
+
|
740 |
+
cross_attn_output, cross_attn_weights = self.encoder_attn(
|
741 |
+
hidden_states=encoder_layer_norm_output,
|
742 |
+
key_value_states=encoder_hidden_states,
|
743 |
+
attention_mask=encoder_attention_mask,
|
744 |
+
)
|
745 |
+
cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic)
|
746 |
+
cross_attn_output = residual + cross_attn_output
|
747 |
+
cross_attn_output = with_sharding_constraint(cross_attn_output, ("batch", "length", "embed"))
|
748 |
+
|
749 |
+
# Fully Connected
|
750 |
+
residual = cross_attn_output
|
751 |
+
|
752 |
+
post_layer_norm = self.final_layer_norm(cross_attn_output)
|
753 |
+
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
|
754 |
+
|
755 |
+
fc1_output = self.activation_fn(self.fc1(post_layer_norm))
|
756 |
+
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
|
757 |
+
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
|
758 |
+
|
759 |
+
hidden_states = self.fc2(fc1_output)
|
760 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
761 |
+
hidden_states = residual + hidden_states
|
762 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
763 |
+
|
764 |
+
outputs = (hidden_states,)
|
765 |
+
|
766 |
+
if output_attentions:
|
767 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
768 |
+
|
769 |
+
if self.use_scan:
|
770 |
+
if all_hidden_states is not None:
|
771 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
772 |
+
outputs = (
|
773 |
+
outputs,
|
774 |
+
all_hidden_states,
|
775 |
+
)
|
776 |
+
|
777 |
+
return outputs
|
778 |
+
|
779 |
+
|
780 |
+
class FlaxWhisperDecoderLayerCollection(nn.Module):
|
781 |
+
config: WhisperConfig
|
782 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
783 |
+
params_dtype: jnp.dtype = jnp.float32
|
784 |
+
use_scan: bool = False
|
785 |
+
gradient_checkpointing: bool = False
|
786 |
+
|
787 |
+
@nn.compact
|
788 |
+
def __call__(
|
789 |
+
self,
|
790 |
+
hidden_states,
|
791 |
+
attention_mask,
|
792 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
793 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
794 |
+
deterministic: bool = True,
|
795 |
+
init_cache: bool = False,
|
796 |
+
output_attentions: bool = False,
|
797 |
+
output_hidden_states: bool = False,
|
798 |
+
return_dict: bool = True,
|
799 |
+
):
|
800 |
+
# decoder layers
|
801 |
+
all_hidden_states = () if output_hidden_states else None
|
802 |
+
all_self_attns = () if output_attentions else None
|
803 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
804 |
+
|
805 |
+
FlaxWhisperDecoderCheckpointLayer = (
|
806 |
+
remat(
|
807 |
+
FlaxWhisperDecoderLayer,
|
808 |
+
static_argnums=(4, 5, 6),
|
809 |
+
prevent_cse=not self.use_scan,
|
810 |
+
)
|
811 |
+
if self.gradient_checkpointing
|
812 |
+
else FlaxWhisperDecoderLayer
|
813 |
+
)
|
814 |
+
|
815 |
+
if self.use_scan:
|
816 |
+
if output_attentions:
|
817 |
+
raise ValueError("Cannot use `scan` with `output_attentions` set to True")
|
818 |
+
|
819 |
+
input_hidden_states = hidden_states
|
820 |
+
hidden_states = (hidden_states,)
|
821 |
+
|
822 |
+
hidden_states, all_hidden_states = scan_with_axes(
|
823 |
+
FlaxWhisperDecoderCheckpointLayer,
|
824 |
+
variable_axes={"params": 0, "cache": 0},
|
825 |
+
split_rngs={"params": True, "dropout": True},
|
826 |
+
in_axes=(
|
827 |
+
nn.broadcast,
|
828 |
+
nn.broadcast,
|
829 |
+
nn.broadcast,
|
830 |
+
nn.broadcast,
|
831 |
+
nn.broadcast,
|
832 |
+
nn.broadcast,
|
833 |
+
nn.broadcast,
|
834 |
+
),
|
835 |
+
variable_carry="all_hidden_states",
|
836 |
+
length=self.config.decoder_layers,
|
837 |
+
)(
|
838 |
+
self.config,
|
839 |
+
dtype=self.dtype,
|
840 |
+
params_dtype=self.params_dtype,
|
841 |
+
use_scan=True,
|
842 |
+
name="FlaxDecoderScanLayers",
|
843 |
+
)(
|
844 |
+
hidden_states,
|
845 |
+
attention_mask,
|
846 |
+
encoder_hidden_states,
|
847 |
+
encoder_attention_mask,
|
848 |
+
init_cache,
|
849 |
+
output_attentions,
|
850 |
+
deterministic,
|
851 |
+
all_hidden_states,
|
852 |
+
)
|
853 |
+
hidden_states = hidden_states[0]
|
854 |
+
|
855 |
+
if output_hidden_states:
|
856 |
+
# if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
|
857 |
+
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
|
858 |
+
|
859 |
+
else:
|
860 |
+
for layer_idx in range(self.config.decoder_layers):
|
861 |
+
if output_hidden_states:
|
862 |
+
all_hidden_states += (hidden_states,)
|
863 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
864 |
+
dropout_probability = random.uniform(0, 1)
|
865 |
+
if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
|
866 |
+
layer_outputs = (None, None, None)
|
867 |
+
else:
|
868 |
+
layer_outputs = FlaxWhisperDecoderCheckpointLayer(
|
869 |
+
self.config,
|
870 |
+
dtype=self.dtype,
|
871 |
+
params_dtype=self.params_dtype,
|
872 |
+
name=str(layer_idx),
|
873 |
+
)(
|
874 |
+
hidden_states,
|
875 |
+
attention_mask,
|
876 |
+
encoder_hidden_states,
|
877 |
+
encoder_attention_mask,
|
878 |
+
init_cache,
|
879 |
+
output_attentions,
|
880 |
+
deterministic,
|
881 |
+
)
|
882 |
+
|
883 |
+
hidden_states = layer_outputs[0]
|
884 |
+
if output_attentions:
|
885 |
+
all_self_attns += (layer_outputs[1],)
|
886 |
+
|
887 |
+
if encoder_hidden_states is not None:
|
888 |
+
all_cross_attentions += (layer_outputs[2],)
|
889 |
+
|
890 |
+
# add hidden states from the last decoder layer
|
891 |
+
if output_hidden_states:
|
892 |
+
all_hidden_states += (hidden_states,)
|
893 |
+
|
894 |
+
outputs = [
|
895 |
+
hidden_states,
|
896 |
+
all_hidden_states,
|
897 |
+
all_self_attns,
|
898 |
+
all_cross_attentions,
|
899 |
+
]
|
900 |
+
|
901 |
+
if not return_dict:
|
902 |
+
return tuple(v for v in outputs if v is not None)
|
903 |
+
|
904 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
905 |
+
last_hidden_state=hidden_states,
|
906 |
+
hidden_states=all_hidden_states,
|
907 |
+
attentions=all_self_attns,
|
908 |
+
cross_attentions=all_cross_attentions,
|
909 |
+
)
|
910 |
+
|
911 |
+
|
912 |
+
class FlaxWhisperEncoder(nn.Module):
|
913 |
+
config: WhisperConfig
|
914 |
+
dtype: jnp.dtype = jnp.float32
|
915 |
+
params_dtype: jnp.dtype = jnp.float32
|
916 |
+
use_scan: bool = False
|
917 |
+
gradient_checkpointing: bool = False
|
918 |
+
|
919 |
+
def setup(self) -> None:
|
920 |
+
self.conv1 = Conv(
|
921 |
+
self.config.d_model,
|
922 |
+
kernel_size=(3,),
|
923 |
+
padding=1,
|
924 |
+
dtype=self.dtype,
|
925 |
+
params_dtype=self.params_dtype,
|
926 |
+
kernel_axes=("channels", "num_mel", "embed"),
|
927 |
+
)
|
928 |
+
self.conv2 = Conv(
|
929 |
+
self.config.d_model,
|
930 |
+
kernel_size=(3,),
|
931 |
+
strides=2,
|
932 |
+
padding=1,
|
933 |
+
dtype=self.dtype,
|
934 |
+
params_dtype=self.params_dtype,
|
935 |
+
kernel_axes=("channels", "embed", "num_mel"),
|
936 |
+
)
|
937 |
+
|
938 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
939 |
+
|
940 |
+
self.layers = FlaxWhisperEncoderLayerCollection(
|
941 |
+
self.config,
|
942 |
+
dtype=self.dtype,
|
943 |
+
params_dtype=self.params_dtype,
|
944 |
+
use_scan=self.use_scan,
|
945 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
946 |
+
)
|
947 |
+
self.embed_positions = Embed(
|
948 |
+
self.config.max_source_positions,
|
949 |
+
self.config.d_model,
|
950 |
+
dtype=self.dtype,
|
951 |
+
params_dtype=self.params_dtype,
|
952 |
+
)
|
953 |
+
|
954 |
+
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
955 |
+
|
956 |
+
def __call__(
|
957 |
+
self,
|
958 |
+
input_features: jnp.ndarray,
|
959 |
+
output_attentions: bool = False,
|
960 |
+
output_hidden_states: bool = False,
|
961 |
+
return_dict: bool = True,
|
962 |
+
deterministic: bool = True,
|
963 |
+
) -> Tuple[jnp.ndarray]:
|
964 |
+
if input_features.shape[1:] != (
|
965 |
+
self.config.num_mel_bins,
|
966 |
+
self.config.max_source_positions * 2,
|
967 |
+
):
|
968 |
+
raise ValueError(
|
969 |
+
"input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
|
970 |
+
" self.config.max_source_positions * 2) (got"
|
971 |
+
f" {input_features.shape[1:]}, but should be"
|
972 |
+
f" ({self.config.num_mel_bins},"
|
973 |
+
f" {self.config.max_source_positions * 2}))"
|
974 |
+
)
|
975 |
+
|
976 |
+
input_features = input_features.transpose(0, 2, 1)
|
977 |
+
hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
|
978 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
|
979 |
+
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
|
980 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
981 |
+
|
982 |
+
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
|
983 |
+
# sinusoidal positional embeddings should not be trained
|
984 |
+
embed_positions = jax.lax.stop_gradient(embed_positions)
|
985 |
+
hidden_states = hidden_states + embed_positions
|
986 |
+
|
987 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
988 |
+
|
989 |
+
outputs = self.layers(
|
990 |
+
hidden_states,
|
991 |
+
attention_mask=None,
|
992 |
+
deterministic=deterministic,
|
993 |
+
output_attentions=output_attentions,
|
994 |
+
output_hidden_states=output_hidden_states,
|
995 |
+
return_dict=return_dict,
|
996 |
+
)
|
997 |
+
|
998 |
+
last_hidden_states = outputs[0]
|
999 |
+
last_hidden_states = self.layer_norm(last_hidden_states)
|
1000 |
+
|
1001 |
+
# update the last element in `hidden_states` after applying `layernorm` above
|
1002 |
+
hidden_states = None
|
1003 |
+
if output_hidden_states:
|
1004 |
+
hidden_states = outputs[1]
|
1005 |
+
if self.use_scan:
|
1006 |
+
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
|
1007 |
+
else:
|
1008 |
+
hidden_states = hidden_states[:-1] + (last_hidden_states,)
|
1009 |
+
|
1010 |
+
if not return_dict:
|
1011 |
+
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
|
1012 |
+
return tuple(v for v in outputs if v is not None)
|
1013 |
+
|
1014 |
+
return FlaxBaseModelOutput(
|
1015 |
+
last_hidden_state=last_hidden_states,
|
1016 |
+
hidden_states=hidden_states,
|
1017 |
+
attentions=outputs.attentions,
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
|
1021 |
+
class FlaxWhisperDecoder(nn.Module):
|
1022 |
+
config: WhisperConfig
|
1023 |
+
dtype: jnp.dtype = jnp.float32
|
1024 |
+
params_dtype: jnp.dtype = jnp.float32
|
1025 |
+
use_scan: bool = False
|
1026 |
+
gradient_checkpointing: bool = False
|
1027 |
+
|
1028 |
+
def setup(self) -> None:
|
1029 |
+
self.embed_tokens = Embed(
|
1030 |
+
self.config.vocab_size,
|
1031 |
+
self.config.d_model,
|
1032 |
+
dtype=self.dtype,
|
1033 |
+
params_dtype=self.params_dtype,
|
1034 |
+
)
|
1035 |
+
self.embed_positions = Embed(
|
1036 |
+
self.config.max_target_positions,
|
1037 |
+
self.config.d_model,
|
1038 |
+
dtype=self.dtype,
|
1039 |
+
params_dtype=self.params_dtype,
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
self.layers = FlaxWhisperDecoderLayerCollection(
|
1043 |
+
self.config,
|
1044 |
+
dtype=self.dtype,
|
1045 |
+
params_dtype=self.params_dtype,
|
1046 |
+
use_scan=self.use_scan,
|
1047 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
1051 |
+
|
1052 |
+
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-5, params_dtype=self.params_dtype)
|
1053 |
+
|
1054 |
+
def __call__(
|
1055 |
+
self,
|
1056 |
+
input_ids: jnp.ndarray,
|
1057 |
+
attention_mask: jnp.ndarray,
|
1058 |
+
position_ids: jnp.ndarray,
|
1059 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
1060 |
+
init_cache: bool = False,
|
1061 |
+
output_attentions: bool = False,
|
1062 |
+
output_hidden_states: bool = False,
|
1063 |
+
return_dict: bool = True,
|
1064 |
+
deterministic: bool = True,
|
1065 |
+
) -> Tuple[jnp.ndarray]:
|
1066 |
+
input_embeds = self.embed_tokens(input_ids)
|
1067 |
+
position_embeds = self.embed_positions(position_ids)
|
1068 |
+
|
1069 |
+
hidden_states = input_embeds + position_embeds
|
1070 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
1071 |
+
|
1072 |
+
outputs = self.layers(
|
1073 |
+
hidden_states,
|
1074 |
+
attention_mask=attention_mask,
|
1075 |
+
encoder_hidden_states=encoder_hidden_states,
|
1076 |
+
deterministic=deterministic,
|
1077 |
+
init_cache=init_cache,
|
1078 |
+
output_attentions=output_attentions,
|
1079 |
+
output_hidden_states=output_hidden_states,
|
1080 |
+
return_dict=return_dict,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
last_hidden_states = outputs[0]
|
1084 |
+
last_hidden_states = self.layer_norm(last_hidden_states)
|
1085 |
+
|
1086 |
+
# update the last element in `hidden_states` after applying `layernorm` above
|
1087 |
+
hidden_states = None
|
1088 |
+
if output_hidden_states:
|
1089 |
+
hidden_states = outputs[1]
|
1090 |
+
if self.use_scan:
|
1091 |
+
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
|
1092 |
+
else:
|
1093 |
+
hidden_states = hidden_states[:-1] + (last_hidden_states,)
|
1094 |
+
|
1095 |
+
if not return_dict:
|
1096 |
+
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
|
1097 |
+
return tuple(v for v in outputs if v is not None)
|
1098 |
+
|
1099 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
1100 |
+
last_hidden_state=last_hidden_states,
|
1101 |
+
hidden_states=hidden_states,
|
1102 |
+
attentions=outputs.attentions,
|
1103 |
+
cross_attentions=outputs.cross_attentions,
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
|
1107 |
+
class FlaxWhisperModule(nn.Module):
|
1108 |
+
config: WhisperConfig
|
1109 |
+
dtype: jnp.dtype = jnp.float32
|
1110 |
+
params_dtype: jnp.dtype = jnp.float32
|
1111 |
+
use_scan: bool = False
|
1112 |
+
gradient_checkpointing: bool = False
|
1113 |
+
|
1114 |
+
def setup(self) -> None:
|
1115 |
+
self.encoder = FlaxWhisperEncoder(
|
1116 |
+
self.config,
|
1117 |
+
dtype=self.dtype,
|
1118 |
+
params_dtype=self.params_dtype,
|
1119 |
+
use_scan=self.use_scan,
|
1120 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1121 |
+
)
|
1122 |
+
self.decoder = FlaxWhisperDecoder(
|
1123 |
+
self.config,
|
1124 |
+
dtype=self.dtype,
|
1125 |
+
params_dtype=self.params_dtype,
|
1126 |
+
use_scan=self.use_scan,
|
1127 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
def __call__(
|
1131 |
+
self,
|
1132 |
+
input_features: jnp.ndarray,
|
1133 |
+
decoder_input_ids: jnp.ndarray,
|
1134 |
+
decoder_attention_mask: jnp.ndarray,
|
1135 |
+
decoder_position_ids: jnp.ndarray,
|
1136 |
+
output_attentions: bool = False,
|
1137 |
+
output_hidden_states: bool = False,
|
1138 |
+
freeze_encoder: bool = False,
|
1139 |
+
return_dict: bool = True,
|
1140 |
+
deterministic: bool = True,
|
1141 |
+
):
|
1142 |
+
encoder_outputs = self.encoder(
|
1143 |
+
input_features,
|
1144 |
+
output_attentions=output_attentions,
|
1145 |
+
output_hidden_states=output_hidden_states,
|
1146 |
+
return_dict=return_dict,
|
1147 |
+
deterministic=deterministic,
|
1148 |
+
)
|
1149 |
+
|
1150 |
+
encoder_hidden_states = encoder_outputs[0]
|
1151 |
+
|
1152 |
+
if freeze_encoder:
|
1153 |
+
encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states)
|
1154 |
+
|
1155 |
+
decoder_outputs = self.decoder(
|
1156 |
+
input_ids=decoder_input_ids,
|
1157 |
+
attention_mask=decoder_attention_mask,
|
1158 |
+
position_ids=decoder_position_ids,
|
1159 |
+
encoder_hidden_states=encoder_hidden_states,
|
1160 |
+
output_attentions=output_attentions,
|
1161 |
+
output_hidden_states=output_hidden_states,
|
1162 |
+
return_dict=return_dict,
|
1163 |
+
deterministic=deterministic,
|
1164 |
+
)
|
1165 |
+
|
1166 |
+
if not return_dict:
|
1167 |
+
return decoder_outputs + encoder_outputs
|
1168 |
+
|
1169 |
+
return FlaxSeq2SeqModelOutput(
|
1170 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1171 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1172 |
+
decoder_attentions=decoder_outputs.attentions,
|
1173 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1174 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1175 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1176 |
+
encoder_attentions=encoder_outputs.attentions,
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
def _get_encoder_module(self):
|
1180 |
+
return self.encoder
|
1181 |
+
|
1182 |
+
def _get_decoder_module(self):
|
1183 |
+
return self.decoder
|
1184 |
+
|
1185 |
+
|
1186 |
+
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
|
1187 |
+
config_class = WhisperConfig
|
1188 |
+
base_model_prefix: str = "model"
|
1189 |
+
main_input_name = "input_features"
|
1190 |
+
module_class: nn.Module = None
|
1191 |
+
|
1192 |
+
def __init__(
|
1193 |
+
self,
|
1194 |
+
config: WhisperConfig,
|
1195 |
+
input_shape: Tuple[int, int, int] = None,
|
1196 |
+
seed: int = 0,
|
1197 |
+
dtype: jnp.dtype = jnp.float32,
|
1198 |
+
params_dtype: jnp.dtype = jnp.float32,
|
1199 |
+
_do_init: bool = True,
|
1200 |
+
# Can only use_scan=True in init if loading scanned weights -> need to handle use_scan=True and unrolled weights
|
1201 |
+
use_scan: bool = False,
|
1202 |
+
gradient_checkpointing: bool = False,
|
1203 |
+
**kwargs,
|
1204 |
+
):
|
1205 |
+
self.use_scan = use_scan
|
1206 |
+
self.gradient_checkpointing = gradient_checkpointing
|
1207 |
+
|
1208 |
+
module = self.module_class(
|
1209 |
+
config=config,
|
1210 |
+
dtype=dtype,
|
1211 |
+
params_dtype=params_dtype,
|
1212 |
+
use_scan=use_scan,
|
1213 |
+
gradient_checkpointing=gradient_checkpointing,
|
1214 |
+
**kwargs,
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
if input_shape is None:
|
1218 |
+
input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
|
1219 |
+
|
1220 |
+
super().__init__(
|
1221 |
+
config,
|
1222 |
+
module,
|
1223 |
+
input_shape=input_shape,
|
1224 |
+
seed=seed,
|
1225 |
+
dtype=dtype,
|
1226 |
+
_do_init=_do_init,
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
1230 |
+
# init input tensors
|
1231 |
+
input_features = jnp.zeros(input_shape, dtype="f4")
|
1232 |
+
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
|
1233 |
+
|
1234 |
+
decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
|
1235 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1236 |
+
|
1237 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1238 |
+
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
1239 |
+
|
1240 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
1241 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
1242 |
+
|
1243 |
+
random_params = self.module.init(
|
1244 |
+
rngs,
|
1245 |
+
input_features=input_features,
|
1246 |
+
decoder_input_ids=decoder_input_ids,
|
1247 |
+
decoder_attention_mask=decoder_attention_mask,
|
1248 |
+
decoder_position_ids=decoder_position_ids,
|
1249 |
+
)["params"]
|
1250 |
+
|
1251 |
+
if params is not None:
|
1252 |
+
random_params = flatten_dict(unfreeze(random_params))
|
1253 |
+
params = flatten_dict(unfreeze(params))
|
1254 |
+
for missing_key in self._missing_keys:
|
1255 |
+
params[missing_key] = random_params[missing_key]
|
1256 |
+
self._missing_keys = set()
|
1257 |
+
return freeze(unflatten_dict(params))
|
1258 |
+
else:
|
1259 |
+
return random_params
|
1260 |
+
|
1261 |
+
def enable_gradient_checkpointing(self):
|
1262 |
+
self.gradient_checkpointing = True
|
1263 |
+
self._module = self.module_class(
|
1264 |
+
config=self.config,
|
1265 |
+
dtype=self.dtype,
|
1266 |
+
use_scan=self.use_scan,
|
1267 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1268 |
+
)
|
1269 |
+
|
1270 |
+
def enable_scan(self):
|
1271 |
+
self.use_scan = True
|
1272 |
+
self._module = self.module_class(
|
1273 |
+
config=self.config,
|
1274 |
+
dtype=self.dtype,
|
1275 |
+
use_scan=self.use_scan,
|
1276 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1277 |
+
)
|
1278 |
+
init_fn = partial(self.init_weights, input_shape=self.input_shape)
|
1279 |
+
params_shape_tree = jax.eval_shape(init_fn, self.key)
|
1280 |
+
|
1281 |
+
# get the shape of the parameters
|
1282 |
+
self._params_shape_tree = params_shape_tree
|
1283 |
+
|
1284 |
+
# save required_params as set
|
1285 |
+
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
1286 |
+
|
1287 |
+
# initialize the parameters
|
1288 |
+
if self._is_initialized:
|
1289 |
+
self.params = self.convert_unroll_to_scan(self.params)
|
1290 |
+
|
1291 |
+
def disable_scan(self):
|
1292 |
+
self.use_scan = False
|
1293 |
+
self._module = self.module_class(
|
1294 |
+
config=self.config,
|
1295 |
+
dtype=self.dtype,
|
1296 |
+
use_scan=self.use_scan,
|
1297 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1298 |
+
)
|
1299 |
+
init_fn = partial(self.init_weights, input_shape=self.input_shape)
|
1300 |
+
params_shape_tree = jax.eval_shape(init_fn, self.key)
|
1301 |
+
|
1302 |
+
# get the shape of the parameters
|
1303 |
+
self._params_shape_tree = params_shape_tree
|
1304 |
+
|
1305 |
+
# save required_params as set
|
1306 |
+
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
1307 |
+
|
1308 |
+
# initialize the parameters
|
1309 |
+
if self._is_initialized:
|
1310 |
+
self.params = self.convert_scan_to_unroll(self.params)
|
1311 |
+
|
1312 |
+
def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
|
1313 |
+
r"""
|
1314 |
+
Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
|
1315 |
+
to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
|
1316 |
+
convert the `params` in place.
|
1317 |
+
|
1318 |
+
To illustrate the workings of this method, take the Flax BERT model. The unrolled structure for the query
|
1319 |
+
projection params is as follows:
|
1320 |
+
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
|
1321 |
+
'q_proj') ... ('bert', 'encoder', 'layer', '23', 'self_attn', 'q_proj')
|
1322 |
+
This method takes each of the `q_proj` matrices for layers (0, ..., 23) and stacks them into a single 'super'
|
1323 |
+
matrix, giving a *single* block of weights for all 24 layers compatible with the scanned model:
|
1324 |
+
('bert', 'encoder', 'layer', 'ScanLayers', 'self_attn', 'q_proj')
|
1325 |
+
|
1326 |
+
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
|
1327 |
+
_do_init=False, it will have to be called explicitly (see example below).
|
1328 |
+
|
1329 |
+
Arguments:
|
1330 |
+
params (`Union[Dict, FrozenDict]`):
|
1331 |
+
A `PyTree` of model parameters.
|
1332 |
+
|
1333 |
+
Examples:
|
1334 |
+
|
1335 |
+
```python
|
1336 |
+
>>> from distil_whisper import FlaxWhisperForConditionalGeneration
|
1337 |
+
|
1338 |
+
>>> # Download model and configuration from huggingface.co
|
1339 |
+
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
|
1340 |
+
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
|
1341 |
+
>>> # we'll first convert to scan format and then back to unrolled
|
1342 |
+
>>> model.enable_scan()
|
1343 |
+
>>> params = model.convert_unroll_to_scan(params)
|
1344 |
+
>>> # now convert back to unrolled
|
1345 |
+
>>> model.disable_scan()
|
1346 |
+
>>> params = model.convert_scan_to_unroll(params)
|
1347 |
+
```"""
|
1348 |
+
if isinstance(params, FrozenDict):
|
1349 |
+
params = unfreeze(params)
|
1350 |
+
|
1351 |
+
params = flatten_dict(params, sep="/")
|
1352 |
+
keys = list(params.keys())
|
1353 |
+
|
1354 |
+
for k in keys:
|
1355 |
+
# Identify all "unrolled" layers formed as part of the FlaxBertLayerCollection
|
1356 |
+
# These params contain the identifier `layer` in their key
|
1357 |
+
if "layers/0" in k:
|
1358 |
+
if "decoder" in k:
|
1359 |
+
block_prefix = "Decoder"
|
1360 |
+
num_hidden_layers = self.config.decoder_layers
|
1361 |
+
else:
|
1362 |
+
block_prefix = "Encoder"
|
1363 |
+
num_hidden_layers = self.config.encoder_layers
|
1364 |
+
|
1365 |
+
# Squash the keys for the N unrolled layers into one single key:
|
1366 |
+
# (layer/0, ..., layer/N) -> layer/FlaxScanLayers
|
1367 |
+
scan_key = k.replace("0", f"Flax{block_prefix}ScanLayers")
|
1368 |
+
stacked_params = []
|
1369 |
+
|
1370 |
+
# Iterate over the unrolled layers (1,...,N)
|
1371 |
+
for i in range(num_hidden_layers):
|
1372 |
+
# Stack the params for the N layers into one super block
|
1373 |
+
# and remove the unrolled layer params on the fly
|
1374 |
+
# -> no memory overhead for conversion!
|
1375 |
+
unrolled_layer = params.pop(k.replace("0", str(i)))
|
1376 |
+
stacked_params.append(unrolled_layer)
|
1377 |
+
|
1378 |
+
params[scan_key] = jnp.stack(stacked_params)
|
1379 |
+
|
1380 |
+
# Finally, unflatten the dict to restore the nested pytree structure
|
1381 |
+
params = unflatten_dict(params, sep="/")
|
1382 |
+
return params
|
1383 |
+
|
1384 |
+
def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
|
1385 |
+
r"""
|
1386 |
+
Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
|
1387 |
+
used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
|
1388 |
+
not convert the `params` in place.
|
1389 |
+
|
1390 |
+
To illustrate the workings of this method, take the Flax BERT model. The scanned structure for the query
|
1391 |
+
projection (`q_proj`) params is a single, stacked matrix of parameters over all N layers:
|
1392 |
+
('bert', 'encoder', 'layer', 'FlaxScanLayers', 'self_attn', 'q_proj')
|
1393 |
+
|
1394 |
+
This method slices each layer of the `q_proj` scanned matrix into single, standalone layers, and replaces the
|
1395 |
+
scanned matrix of parameteres on the fly:
|
1396 |
+
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
|
1397 |
+
'q_proj') ... ('bert', 'encoder', 'layer', 'N', 'self_attn', 'q_proj')
|
1398 |
+
|
1399 |
+
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
|
1400 |
+
_do_init=False, it will have to be called explicitly (see example below).
|
1401 |
+
|
1402 |
+
Arguments:
|
1403 |
+
params (`Union[Dict, FrozenDict]`):
|
1404 |
+
A `PyTree` of model parameters.
|
1405 |
+
|
1406 |
+
Examples:
|
1407 |
+
|
1408 |
+
```python
|
1409 |
+
>>> from distil_whisper import FlaxWhisperForConditionalGeneration
|
1410 |
+
|
1411 |
+
>>> # Download model and configuration from huggingface.co
|
1412 |
+
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
|
1413 |
+
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
|
1414 |
+
>>> # we'll first convert to scan format and then back to unrolled
|
1415 |
+
>>> model.enable_scan()
|
1416 |
+
>>> params = model.convert_unroll_to_scan(params)
|
1417 |
+
>>> # now convert back to unrolled
|
1418 |
+
>>> model.disable_scan()
|
1419 |
+
>>> params = model.convert_scan_to_unroll(params)
|
1420 |
+
```"""
|
1421 |
+
|
1422 |
+
if isinstance(params, FrozenDict):
|
1423 |
+
params = unfreeze(params)
|
1424 |
+
|
1425 |
+
params = flatten_dict(params, sep="/")
|
1426 |
+
keys = list(params.keys())
|
1427 |
+
|
1428 |
+
for k in keys:
|
1429 |
+
# Identify all "scan" layers formed as part of the FlaxBertLayerCollection
|
1430 |
+
# These params contain the identifier `FlaxScanLayers` in their key
|
1431 |
+
if "FlaxEncoderScanLayers" in k:
|
1432 |
+
# Remove the scan layer from the PyTree of params
|
1433 |
+
scan_layer = params.pop(k)
|
1434 |
+
|
1435 |
+
# Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
|
1436 |
+
# layer/FlaxScanLayers -> (layer/0, ..., layer/N)
|
1437 |
+
for i in range(self.config.encoder_layers):
|
1438 |
+
# Unstack the params for the i-th scan layer to unrolled
|
1439 |
+
# and remove corresponding scan params on the fly
|
1440 |
+
# -> no memory overhead for conversion!
|
1441 |
+
unrolled_key = k.replace("FlaxEncoderScanLayers", str(i))
|
1442 |
+
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
|
1443 |
+
|
1444 |
+
elif "FlaxDecoderScanLayers" in k:
|
1445 |
+
# Remove the scan layer from the PyTree of params
|
1446 |
+
scan_layer = params.pop(k)
|
1447 |
+
|
1448 |
+
# Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
|
1449 |
+
# layer/FlaxScanLayers -> (layer/0, ..., layer/N)
|
1450 |
+
for i in range(self.config.decoder_layers):
|
1451 |
+
# Unstack the params for the i-th scan layer to unrolled
|
1452 |
+
# and remove corresponding scan params on the fly
|
1453 |
+
# -> no memory overhead for conversion!
|
1454 |
+
unrolled_key = k.replace("FlaxDecoderScanLayers", str(i))
|
1455 |
+
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
|
1456 |
+
|
1457 |
+
params = unflatten_dict(params, sep="/")
|
1458 |
+
return params
|
1459 |
+
|
1460 |
+
# Copied from transformers.models.whisper.modeling_flax_whisper.FlaxWhisperPreTrainedModel.init_cache
|
1461 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
|
1462 |
+
r"""
|
1463 |
+
Args:
|
1464 |
+
batch_size (`int`):
|
1465 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
1466 |
+
max_length (`int`):
|
1467 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
1468 |
+
cache.
|
1469 |
+
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
|
1470 |
+
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
1471 |
+
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
|
1472 |
+
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
1473 |
+
cross-attention of the decoder.
|
1474 |
+
"""
|
1475 |
+
# init input variables to retrieve cache
|
1476 |
+
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
1477 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1478 |
+
decoder_position_ids = jnp.broadcast_to(
|
1479 |
+
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
|
1480 |
+
decoder_input_ids.shape,
|
1481 |
+
)
|
1482 |
+
|
1483 |
+
def _decoder_forward(
|
1484 |
+
module,
|
1485 |
+
decoder_input_ids,
|
1486 |
+
decoder_attention_mask,
|
1487 |
+
decoder_position_ids,
|
1488 |
+
**kwargs,
|
1489 |
+
):
|
1490 |
+
decoder_module = module._get_decoder_module()
|
1491 |
+
return decoder_module(
|
1492 |
+
decoder_input_ids,
|
1493 |
+
decoder_attention_mask,
|
1494 |
+
decoder_position_ids,
|
1495 |
+
**kwargs,
|
1496 |
+
)
|
1497 |
+
|
1498 |
+
init_variables = self.module.init(
|
1499 |
+
jax.random.PRNGKey(0),
|
1500 |
+
decoder_input_ids=decoder_input_ids,
|
1501 |
+
decoder_attention_mask=decoder_attention_mask,
|
1502 |
+
decoder_position_ids=decoder_position_ids,
|
1503 |
+
encoder_hidden_states=encoder_outputs[0],
|
1504 |
+
init_cache=True,
|
1505 |
+
method=_decoder_forward, # we only need to call the decoder to init the cache
|
1506 |
+
)
|
1507 |
+
return unfreeze(init_variables["cache"])
|
1508 |
+
|
1509 |
+
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
|
1510 |
+
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
|
1511 |
+
def encode(
|
1512 |
+
self,
|
1513 |
+
input_features: jnp.ndarray,
|
1514 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
1515 |
+
output_attentions: Optional[bool] = None,
|
1516 |
+
output_hidden_states: Optional[bool] = None,
|
1517 |
+
return_dict: Optional[bool] = None,
|
1518 |
+
train: bool = False,
|
1519 |
+
params: dict = None,
|
1520 |
+
dropout_rng: PRNGKey = None,
|
1521 |
+
**kwargs,
|
1522 |
+
):
|
1523 |
+
r"""
|
1524 |
+
Returns:
|
1525 |
+
|
1526 |
+
Example:
|
1527 |
+
|
1528 |
+
```python
|
1529 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1530 |
+
>>> from datasets import load_dataset
|
1531 |
+
|
1532 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1533 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1534 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1535 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1536 |
+
>>> input_features = inputs.input_features
|
1537 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1538 |
+
```"""
|
1539 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1540 |
+
output_hidden_states = (
|
1541 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1542 |
+
)
|
1543 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1544 |
+
|
1545 |
+
# Handle any PRNG if needed
|
1546 |
+
rngs = {}
|
1547 |
+
if dropout_rng is not None:
|
1548 |
+
rngs["dropout"] = dropout_rng
|
1549 |
+
|
1550 |
+
def _encoder_forward(module, input_features, **kwargs):
|
1551 |
+
encode_module = module._get_encoder_module()
|
1552 |
+
return encode_module(input_features, **kwargs)
|
1553 |
+
|
1554 |
+
return self.module.apply(
|
1555 |
+
{"params": params or self.params},
|
1556 |
+
input_features=jnp.array(input_features, dtype="f4"),
|
1557 |
+
output_attentions=output_attentions,
|
1558 |
+
output_hidden_states=output_hidden_states,
|
1559 |
+
return_dict=return_dict,
|
1560 |
+
deterministic=not train,
|
1561 |
+
rngs=rngs,
|
1562 |
+
method=_encoder_forward,
|
1563 |
+
)
|
1564 |
+
|
1565 |
+
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
|
1566 |
+
@replace_return_docstrings(
|
1567 |
+
output_type=FlaxBaseModelOutputWithPastAndCrossAttentions,
|
1568 |
+
config_class=WhisperConfig,
|
1569 |
+
)
|
1570 |
+
def decode(
|
1571 |
+
self,
|
1572 |
+
decoder_input_ids,
|
1573 |
+
encoder_outputs,
|
1574 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1575 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1576 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1577 |
+
past_key_values: dict = None,
|
1578 |
+
output_attentions: Optional[bool] = None,
|
1579 |
+
output_hidden_states: Optional[bool] = None,
|
1580 |
+
return_dict: Optional[bool] = None,
|
1581 |
+
train: bool = False,
|
1582 |
+
params: dict = None,
|
1583 |
+
dropout_rng: PRNGKey = None,
|
1584 |
+
):
|
1585 |
+
r"""
|
1586 |
+
Returns:
|
1587 |
+
|
1588 |
+
Example:
|
1589 |
+
|
1590 |
+
```python
|
1591 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1592 |
+
>>> from datasets import load_dataset
|
1593 |
+
|
1594 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1595 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1596 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1597 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1598 |
+
>>> input_features = inputs.input_features
|
1599 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1600 |
+
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
1601 |
+
|
1602 |
+
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
1603 |
+
|
1604 |
+
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
1605 |
+
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
1606 |
+
```"""
|
1607 |
+
|
1608 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1609 |
+
output_hidden_states = (
|
1610 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1611 |
+
)
|
1612 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1613 |
+
|
1614 |
+
encoder_hidden_states = encoder_outputs[0]
|
1615 |
+
|
1616 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1617 |
+
if decoder_position_ids is None:
|
1618 |
+
if past_key_values is not None:
|
1619 |
+
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
|
1620 |
+
|
1621 |
+
if decoder_attention_mask is not None:
|
1622 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1623 |
+
else:
|
1624 |
+
decoder_position_ids = jnp.broadcast_to(
|
1625 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1626 |
+
)
|
1627 |
+
|
1628 |
+
if decoder_attention_mask is None:
|
1629 |
+
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
1630 |
+
|
1631 |
+
# Handle any PRNG if needed
|
1632 |
+
rngs = {}
|
1633 |
+
if dropout_rng is not None:
|
1634 |
+
rngs["dropout"] = dropout_rng
|
1635 |
+
|
1636 |
+
inputs = {"params": params or self.params}
|
1637 |
+
|
1638 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
1639 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
1640 |
+
# it can be changed by FlaxWhisperAttention module
|
1641 |
+
if past_key_values:
|
1642 |
+
inputs["cache"] = past_key_values
|
1643 |
+
mutable = ["cache"]
|
1644 |
+
else:
|
1645 |
+
mutable = False
|
1646 |
+
|
1647 |
+
def _decoder_forward(
|
1648 |
+
module,
|
1649 |
+
decoder_input_ids,
|
1650 |
+
decoder_attention_mask,
|
1651 |
+
decoder_position_ids,
|
1652 |
+
**kwargs,
|
1653 |
+
):
|
1654 |
+
decoder_module = module._get_decoder_module()
|
1655 |
+
return decoder_module(
|
1656 |
+
input_ids=decoder_input_ids,
|
1657 |
+
attention_mask=decoder_attention_mask,
|
1658 |
+
position_ids=decoder_position_ids,
|
1659 |
+
**kwargs,
|
1660 |
+
)
|
1661 |
+
|
1662 |
+
outputs = self.module.apply(
|
1663 |
+
inputs,
|
1664 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1665 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1666 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1667 |
+
encoder_hidden_states=encoder_hidden_states,
|
1668 |
+
output_attentions=output_attentions,
|
1669 |
+
output_hidden_states=output_hidden_states,
|
1670 |
+
return_dict=return_dict,
|
1671 |
+
deterministic=not train,
|
1672 |
+
rngs=rngs,
|
1673 |
+
mutable=mutable,
|
1674 |
+
method=_decoder_forward,
|
1675 |
+
)
|
1676 |
+
|
1677 |
+
# add updated cache to model output
|
1678 |
+
if past_key_values is not None and return_dict:
|
1679 |
+
outputs, past = outputs
|
1680 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
1681 |
+
return outputs
|
1682 |
+
elif past_key_values is not None and not return_dict:
|
1683 |
+
outputs, past = outputs
|
1684 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
1685 |
+
|
1686 |
+
return outputs
|
1687 |
+
|
1688 |
+
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
1689 |
+
def __call__(
|
1690 |
+
self,
|
1691 |
+
input_features: jnp.ndarray,
|
1692 |
+
decoder_input_ids: jnp.ndarray,
|
1693 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
1694 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1695 |
+
position_ids: Optional[jnp.ndarray] = None,
|
1696 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1697 |
+
output_attentions: Optional[bool] = None,
|
1698 |
+
output_hidden_states: Optional[bool] = None,
|
1699 |
+
freeze_encoder: Optional[bool] = None,
|
1700 |
+
return_dict: Optional[bool] = None,
|
1701 |
+
train: bool = False,
|
1702 |
+
params: dict = None,
|
1703 |
+
dropout_rng: PRNGKey = None,
|
1704 |
+
):
|
1705 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1706 |
+
output_hidden_states = (
|
1707 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1708 |
+
)
|
1709 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1710 |
+
|
1711 |
+
# prepare decoder inputs
|
1712 |
+
if decoder_position_ids is None:
|
1713 |
+
if decoder_attention_mask is not None:
|
1714 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1715 |
+
else:
|
1716 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1717 |
+
decoder_position_ids = jnp.broadcast_to(
|
1718 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1719 |
+
)
|
1720 |
+
if decoder_attention_mask is None:
|
1721 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1722 |
+
|
1723 |
+
# Handle any PRNG if needed
|
1724 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
1725 |
+
|
1726 |
+
return self.module.apply(
|
1727 |
+
{"params": params or self.params},
|
1728 |
+
input_features=jnp.array(input_features, dtype="f4"),
|
1729 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1730 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1731 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1732 |
+
output_attentions=output_attentions,
|
1733 |
+
output_hidden_states=output_hidden_states,
|
1734 |
+
freeze_encoder=freeze_encoder,
|
1735 |
+
return_dict=return_dict,
|
1736 |
+
deterministic=not train,
|
1737 |
+
rngs=rngs,
|
1738 |
+
)
|
1739 |
+
|
1740 |
+
|
1741 |
+
@add_start_docstrings(
|
1742 |
+
("The bare Whisper Model transformer outputting raw hidden-states without any specific head on top."),
|
1743 |
+
WHISPER_START_DOCSTRING,
|
1744 |
+
)
|
1745 |
+
class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
|
1746 |
+
config: WhisperConfig
|
1747 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
1748 |
+
params_dtype: jnp.dtype = jnp.float32
|
1749 |
+
module_class = FlaxWhisperModule
|
1750 |
+
|
1751 |
+
|
1752 |
+
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
|
1753 |
+
|
1754 |
+
|
1755 |
+
class FlaxWhisperForConditionalGenerationModule(nn.Module):
|
1756 |
+
config: WhisperConfig
|
1757 |
+
dtype: jnp.dtype = jnp.float32
|
1758 |
+
params_dtype: jnp.dtype = jnp.float32
|
1759 |
+
use_scan: bool = False
|
1760 |
+
gradient_checkpointing: bool = False
|
1761 |
+
|
1762 |
+
def setup(self) -> None:
|
1763 |
+
self.model = FlaxWhisperModule(
|
1764 |
+
config=self.config,
|
1765 |
+
dtype=self.dtype,
|
1766 |
+
params_dtype=self.params_dtype,
|
1767 |
+
use_scan=self.use_scan,
|
1768 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1769 |
+
)
|
1770 |
+
self.lm_head = DenseGeneral(
|
1771 |
+
self.config.vocab_size,
|
1772 |
+
use_bias=False,
|
1773 |
+
dtype=self.dtype,
|
1774 |
+
params_dtype=self.params_dtype,
|
1775 |
+
kernel_axes=("embed", "vocab"),
|
1776 |
+
)
|
1777 |
+
|
1778 |
+
def _get_encoder_module(self):
|
1779 |
+
return self.model.encoder
|
1780 |
+
|
1781 |
+
def _get_decoder_module(self):
|
1782 |
+
return self.model.decoder
|
1783 |
+
|
1784 |
+
def __call__(
|
1785 |
+
self,
|
1786 |
+
input_features,
|
1787 |
+
decoder_input_ids,
|
1788 |
+
decoder_attention_mask: jnp.ndarray = None,
|
1789 |
+
decoder_position_ids: jnp.ndarray = None,
|
1790 |
+
position_ids: jnp.ndarray = None,
|
1791 |
+
attention_mask: jnp.ndarray = None,
|
1792 |
+
output_attentions: bool = False,
|
1793 |
+
output_hidden_states: bool = False,
|
1794 |
+
freeze_encoder: bool = False,
|
1795 |
+
return_dict: bool = True,
|
1796 |
+
deterministic: bool = True,
|
1797 |
+
):
|
1798 |
+
outputs = self.model(
|
1799 |
+
input_features=input_features,
|
1800 |
+
decoder_input_ids=decoder_input_ids,
|
1801 |
+
decoder_attention_mask=decoder_attention_mask,
|
1802 |
+
decoder_position_ids=decoder_position_ids,
|
1803 |
+
output_attentions=output_attentions,
|
1804 |
+
output_hidden_states=output_hidden_states,
|
1805 |
+
freeze_encoder=freeze_encoder,
|
1806 |
+
return_dict=return_dict,
|
1807 |
+
deterministic=deterministic,
|
1808 |
+
)
|
1809 |
+
|
1810 |
+
hidden_states = outputs[0]
|
1811 |
+
|
1812 |
+
if self.config.tie_word_embeddings:
|
1813 |
+
shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
|
1814 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
1815 |
+
else:
|
1816 |
+
lm_logits = self.lm_head(hidden_states)
|
1817 |
+
|
1818 |
+
if not return_dict:
|
1819 |
+
output = (lm_logits,) + outputs[1:]
|
1820 |
+
return output
|
1821 |
+
|
1822 |
+
return FlaxSeq2SeqLMOutput(
|
1823 |
+
logits=lm_logits,
|
1824 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1825 |
+
decoder_attentions=outputs.decoder_attentions,
|
1826 |
+
cross_attentions=outputs.cross_attentions,
|
1827 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1828 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1829 |
+
encoder_attentions=outputs.encoder_attentions,
|
1830 |
+
)
|
1831 |
+
|
1832 |
+
|
1833 |
+
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
|
1834 |
+
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
1835 |
+
module_class = FlaxWhisperForConditionalGenerationModule
|
1836 |
+
|
1837 |
+
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
|
1838 |
+
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
|
1839 |
+
def decode(
|
1840 |
+
self,
|
1841 |
+
decoder_input_ids,
|
1842 |
+
encoder_outputs,
|
1843 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1844 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1845 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1846 |
+
past_key_values: dict = None,
|
1847 |
+
output_attentions: Optional[bool] = None,
|
1848 |
+
output_hidden_states: Optional[bool] = None,
|
1849 |
+
return_dict: Optional[bool] = None,
|
1850 |
+
train: bool = False,
|
1851 |
+
params: dict = None,
|
1852 |
+
dropout_rng: PRNGKey = None,
|
1853 |
+
):
|
1854 |
+
r"""
|
1855 |
+
Returns:
|
1856 |
+
|
1857 |
+
Example:
|
1858 |
+
|
1859 |
+
```python
|
1860 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1861 |
+
>>> from datasets import load_dataset
|
1862 |
+
|
1863 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1864 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1865 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1866 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1867 |
+
>>> input_features = inputs.input_features
|
1868 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1869 |
+
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
1870 |
+
|
1871 |
+
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
1872 |
+
|
1873 |
+
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
1874 |
+
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
1875 |
+
```"""
|
1876 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1877 |
+
output_hidden_states = (
|
1878 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1879 |
+
)
|
1880 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1881 |
+
|
1882 |
+
encoder_hidden_states = encoder_outputs[0]
|
1883 |
+
|
1884 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1885 |
+
if decoder_position_ids is None:
|
1886 |
+
if past_key_values is not None:
|
1887 |
+
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
|
1888 |
+
|
1889 |
+
if decoder_attention_mask is not None:
|
1890 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1891 |
+
else:
|
1892 |
+
decoder_position_ids = jnp.broadcast_to(
|
1893 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1894 |
+
)
|
1895 |
+
if decoder_attention_mask is None:
|
1896 |
+
decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
|
1897 |
+
|
1898 |
+
# Handle any PRNG if needed
|
1899 |
+
rngs = {}
|
1900 |
+
if dropout_rng is not None:
|
1901 |
+
rngs["dropout"] = dropout_rng
|
1902 |
+
|
1903 |
+
inputs = {"params": params or self.params}
|
1904 |
+
|
1905 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
1906 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
1907 |
+
# it can be changed by FlaxWhisperAttention module
|
1908 |
+
if past_key_values:
|
1909 |
+
inputs["cache"] = past_key_values
|
1910 |
+
mutable = ["cache"]
|
1911 |
+
else:
|
1912 |
+
mutable = False
|
1913 |
+
|
1914 |
+
def _decoder_forward(
|
1915 |
+
module,
|
1916 |
+
decoder_input_ids,
|
1917 |
+
decoder_attention_mask,
|
1918 |
+
decoder_position_ids,
|
1919 |
+
**kwargs,
|
1920 |
+
):
|
1921 |
+
decoder_module = module._get_decoder_module()
|
1922 |
+
outputs = decoder_module(
|
1923 |
+
input_ids=decoder_input_ids,
|
1924 |
+
attention_mask=decoder_attention_mask,
|
1925 |
+
position_ids=decoder_position_ids,
|
1926 |
+
**kwargs,
|
1927 |
+
)
|
1928 |
+
hidden_states = outputs[0]
|
1929 |
+
|
1930 |
+
if self.config.tie_word_embeddings:
|
1931 |
+
shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
|
1932 |
+
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
1933 |
+
else:
|
1934 |
+
lm_logits = module.lm_head(hidden_states)
|
1935 |
+
|
1936 |
+
return lm_logits, outputs
|
1937 |
+
|
1938 |
+
outputs = self.module.apply(
|
1939 |
+
inputs,
|
1940 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1941 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1942 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1943 |
+
encoder_hidden_states=encoder_hidden_states,
|
1944 |
+
output_attentions=output_attentions,
|
1945 |
+
output_hidden_states=output_hidden_states,
|
1946 |
+
return_dict=return_dict,
|
1947 |
+
deterministic=not train,
|
1948 |
+
rngs=rngs,
|
1949 |
+
mutable=mutable,
|
1950 |
+
method=_decoder_forward,
|
1951 |
+
)
|
1952 |
+
|
1953 |
+
if past_key_values is None:
|
1954 |
+
lm_logits, decoder_outputs = outputs
|
1955 |
+
else:
|
1956 |
+
(lm_logits, decoder_outputs), past = outputs
|
1957 |
+
|
1958 |
+
if return_dict:
|
1959 |
+
outputs = FlaxCausalLMOutputWithCrossAttentions(
|
1960 |
+
logits=lm_logits,
|
1961 |
+
hidden_states=decoder_outputs.hidden_states,
|
1962 |
+
attentions=decoder_outputs.attentions,
|
1963 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1964 |
+
)
|
1965 |
+
else:
|
1966 |
+
outputs = (lm_logits,) + decoder_outputs[1:]
|
1967 |
+
|
1968 |
+
# add updated cache to model output
|
1969 |
+
if past_key_values is not None and return_dict:
|
1970 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
1971 |
+
return outputs
|
1972 |
+
elif past_key_values is not None and not return_dict:
|
1973 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
1974 |
+
|
1975 |
+
return outputs
|
1976 |
+
|
1977 |
+
def generate(
|
1978 |
+
self,
|
1979 |
+
input_features,
|
1980 |
+
generation_config=None,
|
1981 |
+
logits_processor=None,
|
1982 |
+
return_timestamps=None,
|
1983 |
+
task=None,
|
1984 |
+
language=None,
|
1985 |
+
is_multilingual=None,
|
1986 |
+
**kwargs,
|
1987 |
+
):
|
1988 |
+
if generation_config is None:
|
1989 |
+
generation_config = self.generation_config
|
1990 |
+
|
1991 |
+
if return_timestamps is not None:
|
1992 |
+
generation_config.return_timestamps = return_timestamps
|
1993 |
+
|
1994 |
+
if task is not None:
|
1995 |
+
generation_config.task = task
|
1996 |
+
|
1997 |
+
if is_multilingual is not None:
|
1998 |
+
generation_config.is_multilingual = is_multilingual
|
1999 |
+
|
2000 |
+
if language is not None:
|
2001 |
+
generation_config.language = language
|
2002 |
+
|
2003 |
+
if kwargs is not None and "decoder_input_ids" in kwargs:
|
2004 |
+
decoder_input_length = len(kwargs["decoder_input_ids"])
|
2005 |
+
else:
|
2006 |
+
decoder_input_length = 1
|
2007 |
+
|
2008 |
+
forced_decoder_ids = []
|
2009 |
+
|
2010 |
+
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
|
2011 |
+
if hasattr(generation_config, "language"):
|
2012 |
+
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
|
2013 |
+
else:
|
2014 |
+
forced_decoder_ids.append((1, None))
|
2015 |
+
|
2016 |
+
if hasattr(generation_config, "task"):
|
2017 |
+
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
|
2018 |
+
else:
|
2019 |
+
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
|
2020 |
+
|
2021 |
+
if (
|
2022 |
+
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
|
2023 |
+
) or return_timestamps:
|
2024 |
+
logits_processor = [
|
2025 |
+
FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
|
2026 |
+
]
|
2027 |
+
else:
|
2028 |
+
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
|
2029 |
+
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
2030 |
+
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
2031 |
+
|
2032 |
+
if len(forced_decoder_ids) > 0:
|
2033 |
+
generation_config.forced_decoder_ids = forced_decoder_ids
|
2034 |
+
|
2035 |
+
return super().generate(
|
2036 |
+
input_features,
|
2037 |
+
generation_config,
|
2038 |
+
logits_processor=logits_processor,
|
2039 |
+
**kwargs,
|
2040 |
+
)
|
2041 |
+
|
2042 |
+
def pipeline_generate(
|
2043 |
+
self,
|
2044 |
+
input_features,
|
2045 |
+
forced_decoder_ids,
|
2046 |
+
return_timestamps=False,
|
2047 |
+
generation_config=None,
|
2048 |
+
**kwargs,
|
2049 |
+
):
|
2050 |
+
if generation_config is None:
|
2051 |
+
generation_config = self.generation_config
|
2052 |
+
|
2053 |
+
# override the generation config forced decoder ids in preference of the ones we have set
|
2054 |
+
generation_config.forced_decoder_ids = None
|
2055 |
+
|
2056 |
+
logits_processor = FlaxLogitsProcessorList()
|
2057 |
+
logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids))
|
2058 |
+
|
2059 |
+
if hasattr(generation_config, "return_timestamps") and return_timestamps:
|
2060 |
+
logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1))
|
2061 |
+
|
2062 |
+
return super().generate(
|
2063 |
+
input_features,
|
2064 |
+
generation_config,
|
2065 |
+
logits_processor=logits_processor,
|
2066 |
+
**kwargs,
|
2067 |
+
)
|
2068 |
+
|
2069 |
+
def prepare_inputs_for_generation(
|
2070 |
+
self,
|
2071 |
+
decoder_input_ids,
|
2072 |
+
max_length,
|
2073 |
+
attention_mask: Optional[jax.Array] = None,
|
2074 |
+
decoder_attention_mask: Optional[jax.Array] = None,
|
2075 |
+
encoder_outputs=None,
|
2076 |
+
**kwargs,
|
2077 |
+
):
|
2078 |
+
# initializing the cache
|
2079 |
+
batch_size, seq_length = decoder_input_ids.shape
|
2080 |
+
|
2081 |
+
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
2082 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
2083 |
+
# But since the decoder uses a causal mask, those positions are masked anyways.
|
2084 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
2085 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
2086 |
+
if decoder_attention_mask is not None:
|
2087 |
+
position_ids = decoder_attention_mask.cumsum(-1) - 1
|
2088 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
|
2089 |
+
else:
|
2090 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
2091 |
+
|
2092 |
+
return {
|
2093 |
+
"past_key_values": past_key_values,
|
2094 |
+
"encoder_outputs": encoder_outputs,
|
2095 |
+
"encoder_attention_mask": attention_mask,
|
2096 |
+
"decoder_attention_mask": extended_attention_mask,
|
2097 |
+
"decoder_position_ids": position_ids,
|
2098 |
+
}
|
2099 |
+
|
2100 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
2101 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
2102 |
+
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
2103 |
+
return model_kwargs
|
2104 |
+
|
2105 |
+
|
2106 |
+
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
|
2107 |
+
Returns:
|
2108 |
+
|
2109 |
+
Transcription example:
|
2110 |
+
|
2111 |
+
```python
|
2112 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
2113 |
+
>>> from datasets import load_dataset
|
2114 |
+
|
2115 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
2116 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
2117 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
2118 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
2119 |
+
>>> input_features = inputs.input_features
|
2120 |
+
>>> generated_ids = model.generate(input_ids=input_features)
|
2121 |
+
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
2122 |
+
>>> transcription
|
2123 |
+
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
2124 |
+
```
|
2125 |
+
"""
|
2126 |
+
|
2127 |
+
overwrite_call_docstring(
|
2128 |
+
FlaxWhisperForConditionalGeneration,
|
2129 |
+
WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING,
|
2130 |
+
)
|
2131 |
+
append_replace_return_docstrings(
|
2132 |
+
FlaxWhisperForConditionalGeneration,
|
2133 |
+
output_type=FlaxSeq2SeqLMOutput,
|
2134 |
+
config_class=_CONFIG_FOR_DOC,
|
2135 |
+
)
|
flax/distil_whisper/partitioner.py
ADDED
@@ -0,0 +1,965 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Utilities for partitioning."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
import collections
|
19 |
+
import dataclasses
|
20 |
+
import typing
|
21 |
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
22 |
+
|
23 |
+
import cached_property
|
24 |
+
import jax
|
25 |
+
import numpy as np
|
26 |
+
from absl import logging
|
27 |
+
from flax import traverse_util
|
28 |
+
from flax.linen import partitioning as flax_partitioning
|
29 |
+
from jax import numpy as jnp
|
30 |
+
from jax import random
|
31 |
+
from jax.experimental import multihost_utils
|
32 |
+
from jax.experimental.mesh_utils import create_hybrid_device_mesh
|
33 |
+
from jax.experimental.pjit import pjit as jax_pjit
|
34 |
+
from jax.sharding import Mesh, PartitionSpec
|
35 |
+
|
36 |
+
|
37 |
+
JaxDevice = Any
|
38 |
+
TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores).
|
39 |
+
OtherMesh = Tuple[int, int]
|
40 |
+
HardwareMesh = Union[TpuMesh, OtherMesh]
|
41 |
+
PyTreeDef = type(jax.tree_util.tree_structure(None))
|
42 |
+
TrainState = Any
|
43 |
+
LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]
|
44 |
+
|
45 |
+
if typing.TYPE_CHECKING: # See b/163639353
|
46 |
+
cached_property = property # pylint: disable=invalid-name
|
47 |
+
else:
|
48 |
+
cached_property = cached_property.cached_property
|
49 |
+
|
50 |
+
|
51 |
+
class AxisNames(tuple):
|
52 |
+
"""Tuple of strings specifying name for each axis.
|
53 |
+
|
54 |
+
We create a separate class for this so JAX's pytree utilities can distinguish
|
55 |
+
it from a tuple that should be treated as a pytree, instead treating it as a
|
56 |
+
leaf.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __new__(cls, *names):
|
60 |
+
return tuple.__new__(AxisNames, names)
|
61 |
+
|
62 |
+
def __repr__(self):
|
63 |
+
return "AxisNames%s" % tuple.__repr__(self)
|
64 |
+
|
65 |
+
|
66 |
+
# pjit wrappers for cpu fallback.
|
67 |
+
# ----------------------------------------------------------------------------
|
68 |
+
# TODO(levskaya): This function is now no different than jax_pjit, but callers
|
69 |
+
# currently depend on `backend` argument
|
70 |
+
def pjit(
|
71 |
+
fun: Callable, # pylint: disable=g-bare-generic
|
72 |
+
in_axis_resources,
|
73 |
+
out_axis_resources,
|
74 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
75 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
76 |
+
backend: Optional[str] = None,
|
77 |
+
):
|
78 |
+
"""Wrapper for pjit."""
|
79 |
+
del backend
|
80 |
+
return jax_pjit(
|
81 |
+
fun,
|
82 |
+
in_axis_resources,
|
83 |
+
out_axis_resources,
|
84 |
+
static_argnums=static_argnums,
|
85 |
+
donate_argnums=donate_argnums,
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
# pjit wrappers for cpu fallback.
|
90 |
+
# -----------------------------------------------------------------------------
|
91 |
+
# TODO(levskaya): upstream this fallback behavior to jax pjit.
|
92 |
+
def pjit_with_cpu_fallback(
|
93 |
+
fun: Callable, # pylint: disable=g-bare-generic
|
94 |
+
in_axis_resources,
|
95 |
+
out_axis_resources,
|
96 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
97 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
98 |
+
backend: Optional[str] = None,
|
99 |
+
):
|
100 |
+
"""Wrapper for pjit that calls normal jit on cpu."""
|
101 |
+
if jax.devices(backend)[0].platform == "cpu":
|
102 |
+
return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
|
103 |
+
else:
|
104 |
+
return jax_pjit(
|
105 |
+
fun,
|
106 |
+
in_axis_resources,
|
107 |
+
out_axis_resources,
|
108 |
+
static_argnums=static_argnums,
|
109 |
+
donate_argnums=donate_argnums,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def with_sharding_constraint(x, axis_resources):
|
114 |
+
"""Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
|
115 |
+
if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
|
116 |
+
return x
|
117 |
+
else:
|
118 |
+
return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
|
119 |
+
|
120 |
+
|
121 |
+
# pjit Mesh creation functions.
|
122 |
+
# -----------------------------------------------------------------------------
|
123 |
+
def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
|
124 |
+
"""Get the bound from the given last device."""
|
125 |
+
# Must be passed the device at the highest-coordinate corner of the
|
126 |
+
# relevant mesh, which is a requirement we know is satisfied by the last
|
127 |
+
# device in jax.devices().
|
128 |
+
if hasattr(last_device, "coords"):
|
129 |
+
x, y, z = last_device.coords
|
130 |
+
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
|
131 |
+
else:
|
132 |
+
# On non-TPU platforms, the "mesh" is hosts x devices per host in order
|
133 |
+
# to take advantage of faster within-host interconnect.
|
134 |
+
return jax.host_count(), jax.local_device_count()
|
135 |
+
|
136 |
+
|
137 |
+
def get_coords(device: JaxDevice) -> HardwareMesh:
|
138 |
+
"""Returns the coordinates of the given device."""
|
139 |
+
if hasattr(device, "coords"):
|
140 |
+
return (*device.coords, device.core_on_chip)
|
141 |
+
return (device.process_index, device.id % jax.local_device_count())
|
142 |
+
|
143 |
+
|
144 |
+
def global_mesh_defined():
|
145 |
+
"""Checks if global xmap/pjit mesh resource environment is defined."""
|
146 |
+
maps_env = jax.experimental.maps.thread_resources.env
|
147 |
+
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
|
148 |
+
|
149 |
+
|
150 |
+
def get_mesh(
|
151 |
+
model_parallel_submesh: HardwareMesh,
|
152 |
+
input_devices: Sequence[JaxDevice] = (),
|
153 |
+
input_local_devices: Sequence[JaxDevice] = (),
|
154 |
+
tile_by_host_if_needed: bool = True,
|
155 |
+
backend: Optional[str] = None,
|
156 |
+
) -> Mesh:
|
157 |
+
"""Construct an xmap/pjit Mesh for the given model-parallel submesh.
|
158 |
+
|
159 |
+
The resulting mesh has two resource axes: 'model', with the provided submesh
|
160 |
+
shape, and 'data', which covers the rest of the mesh.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for
|
164 |
+
a single model-parallel replica's "tile" in the physical device mesh. The
|
165 |
+
first three elements (`x`, `y`, and `z`) should be factors of the pod
|
166 |
+
slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
|
167 |
+
(one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z`
|
168 |
+
must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4
|
169 |
+
(and maybe later TPUs) that allow 3D slices. `core` is the number of cores
|
170 |
+
to use from each TPU node. As communication is usually fastest inside the
|
171 |
+
same node, if you need a tile of more than 1 core, then
|
172 |
+
you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better
|
173 |
+
than (2,1,1,1). To pick a good spec, try a few possible values until you
|
174 |
+
get high TPU utilization.
|
175 |
+
input_devices: the devices to use, will use jax.devices() if this is not
|
176 |
+
set.
|
177 |
+
input_local_devices: the local devices to use, will use jax.local_devices()
|
178 |
+
if this is not set.
|
179 |
+
tile_by_host_if_needed: JAX currently requires that the parts of any sharded
|
180 |
+
array that are located on one host's local devices form a single
|
181 |
+
contiguous slice. A best effort will be made to achieve this without
|
182 |
+
"tiling" the device assignment over hosts (which can reduce XLA collective
|
183 |
+
performance). If this flag is True, then the device assignment will be
|
184 |
+
tiled over hosts if necessary to satisfy this constraint and create a
|
185 |
+
buildable mesh; if false, mesh construction will fail instead.
|
186 |
+
backend: get devices from the pinned backend, if specified. This is
|
187 |
+
useful for explicitly specifying the devices other than relying on
|
188 |
+
jax_platform_name.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
A xmap / pjit Mesh containing the virtual device mesh with data, model axes.
|
192 |
+
"""
|
193 |
+
input_devices = input_devices or jax.devices(backend)
|
194 |
+
input_local_devices = input_local_devices or jax.local_devices(0, backend)
|
195 |
+
# Sort input_devices based on coords, as backends might not return devices
|
196 |
+
# in order.
|
197 |
+
last_device = sorted(input_devices, key=get_coords)[-1]
|
198 |
+
last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
|
199 |
+
logging.info(
|
200 |
+
"last device coords : %r\nlast local device coords: %r",
|
201 |
+
get_coords(last_device),
|
202 |
+
get_coords(last_input_local_devices),
|
203 |
+
)
|
204 |
+
global_hardware_mesh = bounds_from_last_device(last_device)
|
205 |
+
mesh_ndim = len(global_hardware_mesh)
|
206 |
+
local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
|
207 |
+
mesh_err = (
|
208 |
+
f"each dimension of the model parallel submesh {model_parallel_submesh} "
|
209 |
+
"must be a factor of the corresponding dimension of the global device "
|
210 |
+
f"mesh {global_hardware_mesh}"
|
211 |
+
)
|
212 |
+
assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
|
213 |
+
assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
|
214 |
+
devices = np.empty(global_hardware_mesh, dtype=object)
|
215 |
+
for device in input_devices:
|
216 |
+
device_coords = get_coords(device)
|
217 |
+
devices[device_coords] = device
|
218 |
+
tile_by_host = tile_by_host_if_needed
|
219 |
+
if len(global_hardware_mesh) == 4:
|
220 |
+
# enable contiguous local chunks without host tiling by making Z major
|
221 |
+
global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
|
222 |
+
model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
|
223 |
+
gx, gy, gz, gc = global_hardware_mesh
|
224 |
+
mx, my, mz, mc = model_parallel_submesh
|
225 |
+
if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
|
226 |
+
logging.info("ensuring YZ plane has a Z-major device order")
|
227 |
+
# YZ should be ZY
|
228 |
+
assert mc == gc, (mc, gc)
|
229 |
+
global_hardware_mesh = gx, gz, gy, gc
|
230 |
+
model_parallel_submesh = mx, mz, my, mc
|
231 |
+
devices = devices.swapaxes(1, 2)
|
232 |
+
tile_by_host = False
|
233 |
+
if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
|
234 |
+
logging.info("ensuring XZ plane has a Z-major device order")
|
235 |
+
# XZ should be ZX
|
236 |
+
assert mc == gc, (mc, gc)
|
237 |
+
global_hardware_mesh = gz, gy, gx, gc
|
238 |
+
model_parallel_submesh = mz, my, mx, mc
|
239 |
+
devices = devices.swapaxes(0, 2)
|
240 |
+
tile_by_host = False
|
241 |
+
if tile_by_host:
|
242 |
+
logging.warning(
|
243 |
+
"Tiling device assignment mesh by hosts, which may lead to "
|
244 |
+
"reduced XLA collective performance. To avoid this, modify "
|
245 |
+
"the model parallel submesh or run with more tasks per host."
|
246 |
+
)
|
247 |
+
tile_err = (
|
248 |
+
"to tile the mesh by hosts, each dimension of the model parallel "
|
249 |
+
"submesh must be either a factor or a multiple of the corresponding "
|
250 |
+
"dimension of the per-host submesh"
|
251 |
+
)
|
252 |
+
|
253 |
+
def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
|
254 |
+
"""Split a global mesh dimension into four tiling components.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
g: global mesh bounds dimension size
|
258 |
+
m: model-parallel submesh bounds dimension size
|
259 |
+
l: local submesh bounds dimension size
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
The resulting tuple divides the dimension into the hosts component of
|
263 |
+
the data-parallel submesh, the devices component of the data-parallel
|
264 |
+
submesh, the hosts component of the model-parallel submesh, and the
|
265 |
+
devices component of the model-parallel submesh.
|
266 |
+
"""
|
267 |
+
d = g // m
|
268 |
+
if m >= l:
|
269 |
+
assert not m % l, tile_err
|
270 |
+
return (d, 1, m // l, l)
|
271 |
+
else:
|
272 |
+
assert not l % m, tile_err
|
273 |
+
return (d // (l // m), l // m, 1, m)
|
274 |
+
|
275 |
+
# e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
|
276 |
+
dh_dd_mh_md_tups = map(
|
277 |
+
dh_dd_mh_md,
|
278 |
+
global_hardware_mesh,
|
279 |
+
model_parallel_submesh,
|
280 |
+
local_hardware_mesh,
|
281 |
+
)
|
282 |
+
# reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
|
283 |
+
devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension
|
284 |
+
# TODO(jekbradbury): reorder local subgroups for ring locality
|
285 |
+
# Transpose to [data_host], [data_device], [model_host], [model_device]
|
286 |
+
# block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
|
287 |
+
devices = devices.transpose(
|
288 |
+
*(4 * i for i in range(mesh_ndim)),
|
289 |
+
*(4 * i + 1 for i in range(mesh_ndim)),
|
290 |
+
*(4 * i + 2 for i in range(mesh_ndim)),
|
291 |
+
*(4 * i + 3 for i in range(mesh_ndim)),
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
# e.g. [(x_data, x_model), (y_data, y_model), ...]
|
295 |
+
model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
|
296 |
+
# reshape to e.g. (x_data, x_model, y_data, y_model...)
|
297 |
+
devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension
|
298 |
+
# TODO(jekbradbury): reorder small subgroups for ring locality
|
299 |
+
# transpose to e.g. (x_data, y_data, ..., x_model, ...)
|
300 |
+
devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
|
301 |
+
# reshape to (data, model)
|
302 |
+
devices = devices.reshape(-1, np.prod(model_parallel_submesh))
|
303 |
+
global_mesh = Mesh(devices, ["data", "model"])
|
304 |
+
logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
|
305 |
+
logging.info("global_mesh devices: %s", global_mesh.devices)
|
306 |
+
logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
|
307 |
+
return global_mesh
|
308 |
+
|
309 |
+
|
310 |
+
def get_cpu_mesh() -> Mesh:
|
311 |
+
"""Trivial mesh for CPU Testing."""
|
312 |
+
devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
|
313 |
+
for device in jax.devices():
|
314 |
+
devices[device.process_index, device.id % jax.local_device_count()] = device
|
315 |
+
return Mesh(devices, ["data", "model"])
|
316 |
+
|
317 |
+
|
318 |
+
def get_gpu_mesh(num_partitions: int) -> Mesh:
|
319 |
+
"""Mesh for GPUs that preferentially places 'model' on NVLink."""
|
320 |
+
nvlink_size = jax.local_device_count()
|
321 |
+
dcn_size = jax.process_count()
|
322 |
+
nvlink_mp = min(num_partitions, nvlink_size)
|
323 |
+
nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
|
324 |
+
dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
|
325 |
+
assert not (
|
326 |
+
extra1 or extra2
|
327 |
+
), "number of partitions on GPU must be a factor or multiple of the number of local devices"
|
328 |
+
dcn_dp = dcn_size // dcn_mp
|
329 |
+
|
330 |
+
devices = create_hybrid_device_mesh(
|
331 |
+
mesh_shape=[nvlink_dp, nvlink_mp],
|
332 |
+
dcn_mesh_shape=[dcn_dp, dcn_mp],
|
333 |
+
process_is_granule=True,
|
334 |
+
)
|
335 |
+
|
336 |
+
global_mesh = Mesh(devices, ["data", "model"])
|
337 |
+
logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
|
338 |
+
logging.info("global_mesh devices: %s", global_mesh.devices)
|
339 |
+
return global_mesh
|
340 |
+
|
341 |
+
|
342 |
+
def default_mesh(
|
343 |
+
num_partitions: int,
|
344 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
345 |
+
backend: Optional[str] = None,
|
346 |
+
) -> Mesh:
|
347 |
+
"""Attempt to return a default mesh for simple cases.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
num_partitions: number of partitions to use, will be ignored if
|
351 |
+
model_parallel_submesh is provided.
|
352 |
+
model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as
|
353 |
+
the model-parallel device tile.
|
354 |
+
backend: get devices from the pinned backend, if specified. This is useful
|
355 |
+
for explicitly specifying the devices other than relying on
|
356 |
+
jax_platform_name.
|
357 |
+
|
358 |
+
Returns:
|
359 |
+
xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
|
360 |
+
"""
|
361 |
+
last_device = jax.devices(backend)[-1]
|
362 |
+
platform = last_device.platform
|
363 |
+
device_kind = last_device.device_kind
|
364 |
+
bounds = bounds_from_last_device(last_device)
|
365 |
+
|
366 |
+
if model_parallel_submesh:
|
367 |
+
return get_mesh(model_parallel_submesh, backend=backend)
|
368 |
+
|
369 |
+
if platform == "cpu":
|
370 |
+
return get_cpu_mesh()
|
371 |
+
elif platform == "gpu":
|
372 |
+
return get_gpu_mesh(num_partitions)
|
373 |
+
|
374 |
+
mps = None
|
375 |
+
if device_kind in ("TPU v2", "TPU v3"):
|
376 |
+
if num_partitions == 1:
|
377 |
+
mps = (1, 1, 1, 1)
|
378 |
+
elif num_partitions == 2:
|
379 |
+
mps = (1, 1, 1, 2)
|
380 |
+
elif num_partitions == 4:
|
381 |
+
mps = (2, 1, 1, 2)
|
382 |
+
elif num_partitions == 8:
|
383 |
+
mps = (2, 2, 1, 2)
|
384 |
+
elif num_partitions == 16:
|
385 |
+
mps = (4, 2, 1, 2)
|
386 |
+
# assume the use of megacore on TPU v4
|
387 |
+
elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
|
388 |
+
if num_partitions == 1:
|
389 |
+
mps = (1, 1, 1, 1)
|
390 |
+
elif num_partitions == 2:
|
391 |
+
mps = (1, 2, 1, 1)
|
392 |
+
elif num_partitions == 4:
|
393 |
+
if bounds[0] >= 4:
|
394 |
+
mps = (4, 1, 1, 1)
|
395 |
+
else:
|
396 |
+
mps = (2, 2, 1, 1)
|
397 |
+
elif num_partitions == 8:
|
398 |
+
if bounds[2] >= 8:
|
399 |
+
mps = (1, 1, 8, 1)
|
400 |
+
else:
|
401 |
+
mps = (4, 2, 1, 1)
|
402 |
+
elif num_partitions == 16:
|
403 |
+
if bounds[2] >= 16:
|
404 |
+
mps = (1, 1, 16, 1)
|
405 |
+
elif bounds[0] >= 8:
|
406 |
+
mps = (8, 2, 1, 1)
|
407 |
+
elif bounds[0] >= 4:
|
408 |
+
mps = (4, 4, 1, 1)
|
409 |
+
else:
|
410 |
+
mps = (2, 2, 4, 1)
|
411 |
+
|
412 |
+
if mps is None:
|
413 |
+
raise ValueError(
|
414 |
+
"No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly."
|
415 |
+
)
|
416 |
+
return get_mesh(mps, backend=backend)
|
417 |
+
|
418 |
+
|
419 |
+
# Data chunking helper.
|
420 |
+
# -----------------------------------------------------------------------------
|
421 |
+
@dataclasses.dataclass
|
422 |
+
class LocalChunkInfo:
|
423 |
+
# The logical slice of an array located on this host's local devices.
|
424 |
+
slice: Tuple[slice, ...]
|
425 |
+
# A unique index for this host/local chunk among chunks with the same slice.
|
426 |
+
replica_id: int
|
427 |
+
|
428 |
+
|
429 |
+
class LocalChunker:
|
430 |
+
"""Utility class to aid chunking of sharded arrays in multihost settings."""
|
431 |
+
|
432 |
+
def __init__(self, global_mesh: Mesh):
|
433 |
+
self.global_mesh = global_mesh
|
434 |
+
local_mesh = global_mesh.local_mesh
|
435 |
+
first_local_device = local_mesh.devices.reshape(-1)[0]
|
436 |
+
host_location = collections.OrderedDict(
|
437 |
+
zip(
|
438 |
+
global_mesh.shape.keys(),
|
439 |
+
list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0],
|
440 |
+
)
|
441 |
+
)
|
442 |
+
self.num_chunks = collections.OrderedDict()
|
443 |
+
self.chunk_ids = collections.OrderedDict()
|
444 |
+
self.mesh_axes = list(global_mesh.shape.keys())
|
445 |
+
for mesh_axis in self.mesh_axes:
|
446 |
+
num_devices_per_chunk = local_mesh.shape[mesh_axis]
|
447 |
+
self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
|
448 |
+
self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk
|
449 |
+
|
450 |
+
def get_local_chunk_info(
|
451 |
+
self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
|
452 |
+
) -> LocalChunkInfo:
|
453 |
+
"""Get the local chunk info for a given array shape and sharded axes.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
global_shape: the global, unsharded shape of the array to chunk.
|
457 |
+
mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
|
458 |
+
that specifies which mesh dimensions the array is sharded along.
|
459 |
+
|
460 |
+
Returns:
|
461 |
+
LocalChunkInfo containing the logical slices of the array found on this
|
462 |
+
host's local devices, as well as the replica index for this chunk among
|
463 |
+
chunks with the same slice. The latter is used to determine which
|
464 |
+
host should write this chunk during checkpointing.
|
465 |
+
"""
|
466 |
+
local_slice = [slice(None) for dim in global_shape]
|
467 |
+
sharded_mesh_axes = set()
|
468 |
+
for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
|
469 |
+
if not mesh_axis:
|
470 |
+
continue
|
471 |
+
sharded_mesh_axes.add(mesh_axis)
|
472 |
+
if not isinstance(mesh_axis, str):
|
473 |
+
raise NotImplementedError("TODO(jekbradbury)")
|
474 |
+
chunk_id = self.chunk_ids[mesh_axis]
|
475 |
+
chunk_size = size // self.num_chunks[mesh_axis]
|
476 |
+
local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
|
477 |
+
|
478 |
+
replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
|
479 |
+
replica_id = 0
|
480 |
+
for mesh_axis in replicated_mesh_axes:
|
481 |
+
chunk_id = self.chunk_ids[mesh_axis]
|
482 |
+
replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
|
483 |
+
|
484 |
+
return LocalChunkInfo(tuple(local_slice), replica_id)
|
485 |
+
|
486 |
+
|
487 |
+
def standard_logical_axis_rules(
|
488 |
+
activation_partitioning_dims: int = 1,
|
489 |
+
parameter_partitioning_dims: int = 1,
|
490 |
+
additional_rules: Optional[LogicalAxisRules] = None,
|
491 |
+
) -> LogicalAxisRules:
|
492 |
+
"""Default sharding rules for T5X model in terms of logical axis names.
|
493 |
+
|
494 |
+
Args:
|
495 |
+
activation_partitioning_dims: enables 2-D activation sharding when set to 2.
|
496 |
+
parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
|
497 |
+
additional_rules: additional rules (a sequence of tuples) that will be
|
498 |
+
appended to the standard rules.
|
499 |
+
|
500 |
+
Returns:
|
501 |
+
Sequence of logical axis rules
|
502 |
+
"""
|
503 |
+
logging.info(
|
504 |
+
"`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
|
505 |
+
activation_partitioning_dims,
|
506 |
+
parameter_partitioning_dims,
|
507 |
+
)
|
508 |
+
|
509 |
+
if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
|
510 |
+
rules = [
|
511 |
+
("batch", "data"),
|
512 |
+
("vocab", "model"),
|
513 |
+
("embed", None),
|
514 |
+
("mlp", "model"),
|
515 |
+
("heads", "model"),
|
516 |
+
("kv", None),
|
517 |
+
("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts
|
518 |
+
]
|
519 |
+
elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
|
520 |
+
rules = [
|
521 |
+
("batch", "data"),
|
522 |
+
("vocab", "model"),
|
523 |
+
("mlp", "model"),
|
524 |
+
("heads", "model"),
|
525 |
+
("kv", None),
|
526 |
+
("joined_kv", "model"),
|
527 |
+
("embed", "model"),
|
528 |
+
]
|
529 |
+
elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
|
530 |
+
rules = [
|
531 |
+
("batch", "data"),
|
532 |
+
("vocab", "model"),
|
533 |
+
("mlp", "model"),
|
534 |
+
("heads", "model"),
|
535 |
+
("kv", None),
|
536 |
+
("joined_kv", "model"),
|
537 |
+
("embed", "data"),
|
538 |
+
]
|
539 |
+
elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
|
540 |
+
rules = [
|
541 |
+
("batch", "data"),
|
542 |
+
("vocab", "model"),
|
543 |
+
("mlp", "model"),
|
544 |
+
("heads", "model"),
|
545 |
+
("kv", None),
|
546 |
+
("joined_kv", "model"),
|
547 |
+
("embed", "model"),
|
548 |
+
("embed", "data"),
|
549 |
+
]
|
550 |
+
else:
|
551 |
+
raise ValueError(
|
552 |
+
f"`activation_partitioning_dims` = {activation_partitioning_dims} "
|
553 |
+
f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
|
554 |
+
"is not supported."
|
555 |
+
)
|
556 |
+
|
557 |
+
# Add the common rules for the replicated logical axes names.
|
558 |
+
replicated_rules = [
|
559 |
+
("relpos_buckets", None),
|
560 |
+
("abspos_buckets", None),
|
561 |
+
("length", None),
|
562 |
+
("layers", None),
|
563 |
+
("stack", None),
|
564 |
+
("mlp_activations", None),
|
565 |
+
]
|
566 |
+
rules.extend(replicated_rules)
|
567 |
+
|
568 |
+
if additional_rules:
|
569 |
+
rules.extend(additional_rules)
|
570 |
+
|
571 |
+
return rules
|
572 |
+
|
573 |
+
|
574 |
+
# NB: This needs to be top-level for the jax compilation cache.
|
575 |
+
def _id_fn(x, ix):
|
576 |
+
"""Identity function for copying parameters to the devices, sharded."""
|
577 |
+
# A pure identity such as `lambda x, *: x` can get optimized away, so we
|
578 |
+
# include a random.split as a cheap function that cannot be optimized away.
|
579 |
+
y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
|
580 |
+
return x, y
|
581 |
+
|
582 |
+
|
583 |
+
@dataclasses.dataclass
|
584 |
+
class DataLayout:
|
585 |
+
"""Represents data layout for the partitioned model."""
|
586 |
+
|
587 |
+
batch_size: int
|
588 |
+
shard_id: int
|
589 |
+
num_shards: int
|
590 |
+
is_first_host_in_replica_set: bool
|
591 |
+
|
592 |
+
|
593 |
+
PartitionedCallable = Callable[..., Any]
|
594 |
+
CompiledPartitionedCallable = Callable[..., Any]
|
595 |
+
|
596 |
+
|
597 |
+
class BasePartitioner(metaclass=abc.ABCMeta):
|
598 |
+
"""Interface for partitioning computations across hardware devices."""
|
599 |
+
|
600 |
+
def __init__(
|
601 |
+
self,
|
602 |
+
num_partitions: Optional[int] = None,
|
603 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
604 |
+
params_on_devices: bool = True,
|
605 |
+
backend: Optional[str] = None,
|
606 |
+
):
|
607 |
+
"""Configures the partitioner.
|
608 |
+
|
609 |
+
Args:
|
610 |
+
num_partitions: the number of partitions to use. Ignored if
|
611 |
+
`model_parallel_submesh` is provided.
|
612 |
+
model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
|
613 |
+
as the model-parallel device tile. This submesh is used for the larger
|
614 |
+
of the two parameter dimensions, and, if 2-D activation sharding is
|
615 |
+
enabled, for the model dimension of activations. The rest of the mesh is
|
616 |
+
used for data parallelism and, if 2-D parameter sharding is enabled, the
|
617 |
+
other parameter dimension.
|
618 |
+
params_on_devices: whether to keep the params on devices, if False -
|
619 |
+
params stay in the host memory. Note that some partitioners might ignore
|
620 |
+
this setting, for example if they don't support storing all params on
|
621 |
+
device memory.
|
622 |
+
backend: get devices from the pinned backend, if specified. This is useful
|
623 |
+
for explicitly specifying the devices other than relying on
|
624 |
+
jax_platform_name.
|
625 |
+
"""
|
626 |
+
|
627 |
+
if not num_partitions and not model_parallel_submesh:
|
628 |
+
raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.")
|
629 |
+
|
630 |
+
if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
|
631 |
+
logging.error(
|
632 |
+
(
|
633 |
+
"`model_parallel_submesh` must be either None or a 4-tuple. Got"
|
634 |
+
" `model_parallel_submesh`=%s. A ValueError will be raised"
|
635 |
+
" beginning March 1, 2022."
|
636 |
+
),
|
637 |
+
model_parallel_submesh,
|
638 |
+
)
|
639 |
+
|
640 |
+
if bool(num_partitions) and bool(model_parallel_submesh):
|
641 |
+
logging.error(
|
642 |
+
(
|
643 |
+
"At most one of `num_partitions` or `model_parallel_submesh` can be"
|
644 |
+
" set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A"
|
645 |
+
" ValueError will be raised beginning March 21, 2022."
|
646 |
+
),
|
647 |
+
num_partitions,
|
648 |
+
model_parallel_submesh,
|
649 |
+
)
|
650 |
+
|
651 |
+
self._num_partitions = num_partitions
|
652 |
+
self._model_parallel_submesh = model_parallel_submesh
|
653 |
+
self._params_on_devices = params_on_devices
|
654 |
+
self._data_axis = "data"
|
655 |
+
self._backend = backend
|
656 |
+
|
657 |
+
@property
|
658 |
+
def mesh(self) -> Mesh:
|
659 |
+
raise NotImplementedError
|
660 |
+
|
661 |
+
@property
|
662 |
+
def data_partition_spec(self) -> PartitionSpec:
|
663 |
+
return PartitionSpec(self._data_axis)
|
664 |
+
|
665 |
+
def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
|
666 |
+
"""Returns filled `DataLayout` based on the partitioned model layout.
|
667 |
+
|
668 |
+
Args:
|
669 |
+
batch_size: if set, indicates the requested batch size. The exception will
|
670 |
+
be raised if this batch size is not compatible with the layout. If not
|
671 |
+
set, the batch size is inferred from the layout.
|
672 |
+
host_index: indicates the host index to use for the calculations, if not
|
673 |
+
set - use JAX-provided one. Should be in [0, num_hosts) interval and the
|
674 |
+
order should match the order of corresponding CPU devices in
|
675 |
+
`jax.devices()`.
|
676 |
+
|
677 |
+
Returns:
|
678 |
+
Filled `DataLayout` structure.
|
679 |
+
"""
|
680 |
+
if host_index is not None:
|
681 |
+
raise NotImplementedError("Explicit host_index is not yet implemented.")
|
682 |
+
if self._data_axis is None:
|
683 |
+
return DataLayout(
|
684 |
+
batch_size=batch_size,
|
685 |
+
shard_id=0,
|
686 |
+
num_shards=1,
|
687 |
+
is_first_host_in_replica_set=(jax.process_index() == 0),
|
688 |
+
)
|
689 |
+
mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
|
690 |
+
batch_size = batch_size or mesh_size
|
691 |
+
if batch_size % mesh_size:
|
692 |
+
raise ValueError(
|
693 |
+
f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})."
|
694 |
+
)
|
695 |
+
num_shards = self._local_chunker.num_chunks[self._data_axis]
|
696 |
+
if batch_size % num_shards:
|
697 |
+
raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).")
|
698 |
+
replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
|
699 |
+
return DataLayout(
|
700 |
+
batch_size=int(batch_size),
|
701 |
+
shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
|
702 |
+
num_shards=int(num_shards),
|
703 |
+
is_first_host_in_replica_set=(replica_id == 0),
|
704 |
+
)
|
705 |
+
|
706 |
+
def get_local_chunk_info(
|
707 |
+
self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
|
708 |
+
) -> LocalChunkInfo:
|
709 |
+
"""Returns the local chunk info for a given array shape and sharded axes."""
|
710 |
+
return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)
|
711 |
+
|
712 |
+
@property
|
713 |
+
def params_on_devices(self):
|
714 |
+
return self._params_on_devices
|
715 |
+
|
716 |
+
def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
|
717 |
+
"""Moves the optimizer parameters to devices."""
|
718 |
+
p_id_fn = self.partition(
|
719 |
+
_id_fn,
|
720 |
+
in_axis_resources=(train_state_axes, None),
|
721 |
+
out_axis_resources=(train_state_axes, None),
|
722 |
+
donate_argnums=(0,),
|
723 |
+
)
|
724 |
+
if jax.config.jax_array and jax.process_count() > 1:
|
725 |
+
train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
|
726 |
+
train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
|
727 |
+
return train_state
|
728 |
+
|
729 |
+
@property
|
730 |
+
@abc.abstractmethod
|
731 |
+
def _local_chunker(self):
|
732 |
+
"""Returns the chunker that matches the parameters of this partitioner."""
|
733 |
+
raise NotImplementedError
|
734 |
+
|
735 |
+
def get_logical_axes(self, train_state: TrainState) -> TrainState:
|
736 |
+
"""Returns a copy of TrainState with Optional[AxisNames] as leaves."""
|
737 |
+
# By default, return None for the logical axes.
|
738 |
+
return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))
|
739 |
+
|
740 |
+
def get_mesh_axes(self, train_state: TrainState) -> TrainState:
|
741 |
+
"""Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
|
742 |
+
raise NotImplementedError
|
743 |
+
|
744 |
+
@abc.abstractmethod
|
745 |
+
def partition(
|
746 |
+
self,
|
747 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
748 |
+
in_axis_resources,
|
749 |
+
out_axis_resources,
|
750 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
751 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
752 |
+
) -> PartitionedCallable:
|
753 |
+
"""Partitions the computation using partitioner-specific implementation.
|
754 |
+
|
755 |
+
Args:
|
756 |
+
fn: the function to partition.
|
757 |
+
in_axis_resources: Pytree of structure matching that of arguments to `fn`,
|
758 |
+
with all actual arguments replaced by resource assignment
|
759 |
+
specifications. It is also valid to specify a pytree prefix (e.g. one
|
760 |
+
value in place of a whole subtree), in which case the leaves get
|
761 |
+
broadcast to all values in that subtree.
|
762 |
+
The valid resource assignment specifications are:
|
763 |
+
`None`: in which case the value will be replicated on all devices
|
764 |
+
`PartitionSpec`: a tuple of length at most equal to the rank of the
|
765 |
+
partitioned value. Each element can be a `None`, a mesh axis or a
|
766 |
+
tuple of mesh axes, and specifies the set of resources assigned to
|
767 |
+
partition the value's dimension matching its position in the spec.
|
768 |
+
out_axis_resources: Like `in_axis_resources`, but specifies resource
|
769 |
+
assignment for function outputs.
|
770 |
+
static_argnums: an optional int or collection of ints that specify which
|
771 |
+
positional arguments to treat as static (compile-time constant) in the
|
772 |
+
partitioned function.
|
773 |
+
donate_argnums: an optional int or collection of ints that specify which
|
774 |
+
argument buffers are "donated" to the computation. It is safe to donate
|
775 |
+
argument buffers if you no longer need them once the computation has
|
776 |
+
finished.
|
777 |
+
|
778 |
+
Returns:
|
779 |
+
A partitioned version of the input function.
|
780 |
+
"""
|
781 |
+
raise NotImplementedError
|
782 |
+
|
783 |
+
@abc.abstractmethod
|
784 |
+
def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
|
785 |
+
"""Compiles and returns the partitioned function, or the original.
|
786 |
+
|
787 |
+
Args:
|
788 |
+
partitioned_fn: The partitioned function.
|
789 |
+
*args: Sample arguments to the partitioned function matching the input
|
790 |
+
shapes that will be passed to the compiled function.
|
791 |
+
|
792 |
+
Returns:
|
793 |
+
The compiled function, or the original if this partitioner does not
|
794 |
+
support compilation.
|
795 |
+
"""
|
796 |
+
raise NotImplementedError
|
797 |
+
|
798 |
+
|
799 |
+
class PjittedFnWithContext(PartitionedCallable):
|
800 |
+
"""Wraps pjitted function to apply the appropriate contexts."""
|
801 |
+
|
802 |
+
def __init__(
|
803 |
+
self,
|
804 |
+
pjitted_fn,
|
805 |
+
partition_mesh: Mesh,
|
806 |
+
logical_axis_rules: flax_partitioning.LogicalRules = (),
|
807 |
+
):
|
808 |
+
self._pjitted_fn = pjitted_fn
|
809 |
+
self._mesh = partition_mesh
|
810 |
+
self._logical_axis_rules = logical_axis_rules
|
811 |
+
|
812 |
+
def __call__(self, *args):
|
813 |
+
with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
|
814 |
+
return self._pjitted_fn(*args)
|
815 |
+
|
816 |
+
def lower(self, *args):
|
817 |
+
with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
|
818 |
+
return self._pjitted_fn.lower(*args)
|
819 |
+
|
820 |
+
|
821 |
+
class BasePjitPartitioner(BasePartitioner):
|
822 |
+
"""Partitioner that uses T5X version of jax.pjit."""
|
823 |
+
|
824 |
+
@cached_property
|
825 |
+
def _local_chunker(self) -> LocalChunker:
|
826 |
+
return LocalChunker(self.mesh)
|
827 |
+
|
828 |
+
@cached_property
|
829 |
+
def mesh(self) -> Mesh:
|
830 |
+
return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)
|
831 |
+
|
832 |
+
def partition(
|
833 |
+
self,
|
834 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
835 |
+
in_axis_resources,
|
836 |
+
out_axis_resources,
|
837 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
838 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
839 |
+
) -> PjittedFnWithContext:
|
840 |
+
pjitted = pjit(
|
841 |
+
fn,
|
842 |
+
in_axis_resources=in_axis_resources,
|
843 |
+
out_axis_resources=out_axis_resources,
|
844 |
+
static_argnums=static_argnums,
|
845 |
+
donate_argnums=donate_argnums,
|
846 |
+
backend=self._backend,
|
847 |
+
)
|
848 |
+
|
849 |
+
return PjittedFnWithContext(pjitted, self.mesh)
|
850 |
+
|
851 |
+
def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
|
852 |
+
return partitioned_fn.lower(*args).compile()
|
853 |
+
|
854 |
+
|
855 |
+
class PjitPartitioner(BasePjitPartitioner):
|
856 |
+
"""Partitioner that uses named axes and jax.pjit."""
|
857 |
+
|
858 |
+
def __init__(
|
859 |
+
self,
|
860 |
+
num_partitions: Optional[int] = None,
|
861 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
862 |
+
params_on_devices: bool = True,
|
863 |
+
backend: Optional[str] = None,
|
864 |
+
logical_axis_rules: Optional[LogicalAxisRules] = None,
|
865 |
+
use_cpu_pjit: Optional[bool] = False,
|
866 |
+
):
|
867 |
+
"""PjitPartitioner constructor.
|
868 |
+
|
869 |
+
See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.
|
870 |
+
|
871 |
+
Args:
|
872 |
+
num_partitions: an integer that specifies the size of the model parallel
|
873 |
+
submesh to be automatically selected for the current topology. See
|
874 |
+
`model_parallel_submesh` for details on how this submesh is used.
|
875 |
+
Mutually exlusive with `model_parallel_submesh`.
|
876 |
+
model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)`
|
877 |
+
submesh model-parallel device tile, an axis of accelerator parallelism
|
878 |
+
orthogonal to data parallelism. Array axes in a model's parameters or
|
879 |
+
activations can be sharded over this submesh using axis rules (see
|
880 |
+
`logical_axis_rules`) that map them to 'model'. The effective number of
|
881 |
+
model sub-partitions is equal to `np.prod(model_parallel_submesh)` and
|
882 |
+
must evenly divide the total number of devices (i.e.,
|
883 |
+
`jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
|
884 |
+
of the TPU mesh is the data parallel submesh, providing
|
885 |
+
`jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
|
886 |
+
is used for data (batch) parallelism and to shard other array axes that
|
887 |
+
are mapped to 'data'. This argument is mutually exclusive with
|
888 |
+
`num_partitions`.
|
889 |
+
params_on_devices: whether to keep the params on devices, if False -
|
890 |
+
params stay in the host memory. Note that some partitioners might ignore
|
891 |
+
this setting, for example if they don't support storing all params on
|
892 |
+
device memory.
|
893 |
+
backend: get devices from the pinned backend, if specified. This is
|
894 |
+
useful for explicitly specifying the devices other than relying on
|
895 |
+
jax_platform_name.
|
896 |
+
logical_axis_rules: a priority-ordered sequence of KV tuples that maps
|
897 |
+
logical axis names to either `None` (not sharded), 'model' (to shard
|
898 |
+
across the model-parallel submesh), or 'data' (to shard across the
|
899 |
+
data-parallel submesh).
|
900 |
+
use_cpu_pjit: enables wrapper function for pjit which just jits the
|
901 |
+
function if using CPU backend.
|
902 |
+
"""
|
903 |
+
super().__init__(
|
904 |
+
num_partitions=num_partitions,
|
905 |
+
model_parallel_submesh=model_parallel_submesh,
|
906 |
+
params_on_devices=params_on_devices,
|
907 |
+
backend=backend,
|
908 |
+
)
|
909 |
+
if logical_axis_rules is None:
|
910 |
+
logical_axis_rules = standard_logical_axis_rules()
|
911 |
+
self._logical_axis_rules = tuple(logical_axis_rules)
|
912 |
+
(self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
|
913 |
+
self._use_cpu_pjit = use_cpu_pjit
|
914 |
+
|
915 |
+
def partition(
|
916 |
+
self,
|
917 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
918 |
+
in_axis_resources,
|
919 |
+
out_axis_resources,
|
920 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
921 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
922 |
+
) -> PjittedFnWithContext:
|
923 |
+
"""Partitions the function using jax.pjit."""
|
924 |
+
if self._use_cpu_pjit:
|
925 |
+
pjit_fn = pjit_with_cpu_fallback
|
926 |
+
else:
|
927 |
+
pjit_fn = pjit
|
928 |
+
pjitted = pjit_fn(
|
929 |
+
fn,
|
930 |
+
in_axis_resources=in_axis_resources,
|
931 |
+
out_axis_resources=out_axis_resources,
|
932 |
+
static_argnums=static_argnums,
|
933 |
+
donate_argnums=donate_argnums,
|
934 |
+
backend=self._backend,
|
935 |
+
)
|
936 |
+
|
937 |
+
return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)
|
938 |
+
|
939 |
+
@property
|
940 |
+
def logical_axis_rules(self):
|
941 |
+
"""Returns the logical axis rules."""
|
942 |
+
return self._logical_axis_rules
|
943 |
+
|
944 |
+
def get_logical_axes(self, train_state: TrainState) -> TrainState:
|
945 |
+
"""Returns a copy of TrainState with Optional[AxisNames] as leaves."""
|
946 |
+
return train_state.as_logical_axes()
|
947 |
+
|
948 |
+
def get_mesh_axes(self, train_state: TrainState) -> TrainState:
|
949 |
+
"""Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
|
950 |
+
logical_axes = self.get_logical_axes(train_state)
|
951 |
+
|
952 |
+
def _logical_to_mesh_axes(param_name, logical_axes):
|
953 |
+
if logical_axes is None:
|
954 |
+
return None
|
955 |
+
elif logical_axes is traverse_util.empty_node:
|
956 |
+
return traverse_util.empty_node
|
957 |
+
try:
|
958 |
+
return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
|
959 |
+
except ValueError as e:
|
960 |
+
raise ValueError(f"Failed to map logical axes for {param_name}") from e
|
961 |
+
|
962 |
+
flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
|
963 |
+
flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}
|
964 |
+
|
965 |
+
return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
|
flax/distil_whisper/pipeline.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Whisper JAX pipeline compatible with Distil Whisper checkpoints. Copied from https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py"""
|
17 |
+
|
18 |
+
import math
|
19 |
+
|
20 |
+
import jax
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import numpy as np
|
23 |
+
import requests
|
24 |
+
import torch
|
25 |
+
from flax import jax_utils
|
26 |
+
from flax.core.frozen_dict import freeze
|
27 |
+
from flax.training.common_utils import shard
|
28 |
+
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
|
29 |
+
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
30 |
+
from transformers.pipelines.audio_utils import ffmpeg_read
|
31 |
+
from transformers.utils import logging
|
32 |
+
|
33 |
+
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
|
40 |
+
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
41 |
+
"""
|
42 |
+
Compute the log-mel spectrogram of the provided audio using torch filters. Using the torch implementation
|
43 |
+
computes stft filter banks approx 5x faster than its numpy counterpart, which is the native implementation
|
44 |
+
in transformers, and matches to within 1e-5 abs tolerance.
|
45 |
+
"""
|
46 |
+
waveform = torch.from_numpy(waveform).type(torch.float32)
|
47 |
+
|
48 |
+
window = torch.hann_window(self.n_fft)
|
49 |
+
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
50 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
51 |
+
|
52 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
53 |
+
mel_spec = mel_filters.T @ magnitudes
|
54 |
+
|
55 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
56 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
57 |
+
log_spec = (log_spec + 4.0) / 4.0
|
58 |
+
return log_spec.numpy()
|
59 |
+
|
60 |
+
|
61 |
+
class FlaxWhisperPipeline:
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
checkpoint="openai/whisper-large-v2",
|
65 |
+
dtype=jnp.float32,
|
66 |
+
batch_size=None,
|
67 |
+
max_length=None,
|
68 |
+
**kwargs,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Args
|
72 |
+
checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"):
|
73 |
+
The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
|
74 |
+
with Flax weights.
|
75 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
76 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
77 |
+
`jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
|
78 |
+
If specified all the computation will be performed with the given `dtype`. **Note that this only
|
79 |
+
specifies the dtype of the computation and does not influence the dtype of model parameters.**
|
80 |
+
batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
|
81 |
+
The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
|
82 |
+
a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method.
|
83 |
+
max_length (`int`, *optional*):
|
84 |
+
The maximum numbers of tokens to generate. Defaults to `model.config.max_length`.
|
85 |
+
"""
|
86 |
+
self.checkpoint = checkpoint
|
87 |
+
self.dtype = dtype
|
88 |
+
|
89 |
+
self.feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(self.checkpoint)
|
90 |
+
self.tokenizer = WhisperTokenizerFast.from_pretrained(self.checkpoint)
|
91 |
+
|
92 |
+
self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
93 |
+
self.checkpoint,
|
94 |
+
_do_init=False,
|
95 |
+
dtype=self.dtype,
|
96 |
+
**kwargs,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
|
100 |
+
self.min_batch_size = jax.local_device_count()
|
101 |
+
self.batch_size = (
|
102 |
+
batch_size if batch_size is not None else self.min_batch_size
|
103 |
+
) # we need a minimum of 1 batch per-device
|
104 |
+
|
105 |
+
def generate(
|
106 |
+
params,
|
107 |
+
input_features,
|
108 |
+
forced_decoder_ids,
|
109 |
+
return_timestamps,
|
110 |
+
num_beams,
|
111 |
+
length_penalty,
|
112 |
+
do_sample,
|
113 |
+
top_k,
|
114 |
+
temperature,
|
115 |
+
):
|
116 |
+
output_ids = self.model.pipeline_generate(
|
117 |
+
input_features,
|
118 |
+
params=params,
|
119 |
+
forced_decoder_ids=forced_decoder_ids,
|
120 |
+
return_timestamps=return_timestamps,
|
121 |
+
max_length=self.max_length,
|
122 |
+
num_beams=num_beams,
|
123 |
+
length_penalty=length_penalty,
|
124 |
+
do_sample=do_sample,
|
125 |
+
top_k=top_k,
|
126 |
+
temperature=temperature,
|
127 |
+
)
|
128 |
+
return output_ids
|
129 |
+
|
130 |
+
self.params = jax_utils.replicate(self.params)
|
131 |
+
self.p_generate = jax.pmap(
|
132 |
+
generate,
|
133 |
+
"input_features",
|
134 |
+
in_axes=(0, 0, None, None, None, None, None, None, None),
|
135 |
+
static_broadcasted_argnums=(
|
136 |
+
3,
|
137 |
+
4,
|
138 |
+
5,
|
139 |
+
6,
|
140 |
+
7,
|
141 |
+
8,
|
142 |
+
),
|
143 |
+
)
|
144 |
+
|
145 |
+
def generate(
|
146 |
+
self,
|
147 |
+
input_features,
|
148 |
+
language=None,
|
149 |
+
task=None,
|
150 |
+
return_timestamps=False,
|
151 |
+
num_beams=1,
|
152 |
+
length_penalty=1.0,
|
153 |
+
do_sample=False,
|
154 |
+
top_k=50,
|
155 |
+
temperature=1.0,
|
156 |
+
):
|
157 |
+
forced_decoder_ids = self.get_forced_decoder_ids(
|
158 |
+
language=language, task=task, return_timestamps=return_timestamps
|
159 |
+
)
|
160 |
+
# if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
|
161 |
+
output_ids = self.p_generate(
|
162 |
+
freeze(self.params),
|
163 |
+
shard(input_features),
|
164 |
+
forced_decoder_ids,
|
165 |
+
return_timestamps,
|
166 |
+
num_beams,
|
167 |
+
length_penalty,
|
168 |
+
do_sample,
|
169 |
+
top_k,
|
170 |
+
temperature,
|
171 |
+
).sequences
|
172 |
+
output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
|
173 |
+
return output_ids
|
174 |
+
|
175 |
+
def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
|
176 |
+
if generation_config is None:
|
177 |
+
generation_config = self.model.generation_config
|
178 |
+
|
179 |
+
if hasattr(generation_config, "is_multilingual"):
|
180 |
+
is_multilingual = generation_config.is_multilingual
|
181 |
+
else:
|
182 |
+
is_multilingual = None
|
183 |
+
|
184 |
+
forced_decoder_ids = []
|
185 |
+
|
186 |
+
if is_multilingual:
|
187 |
+
if language is not None:
|
188 |
+
language = language.lower()
|
189 |
+
if language in generation_config.lang_to_id.keys():
|
190 |
+
language_token = language
|
191 |
+
elif language in TO_LANGUAGE_CODE.values():
|
192 |
+
language_token = f"<|{language}|>"
|
193 |
+
elif language in TO_LANGUAGE_CODE.keys():
|
194 |
+
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
|
195 |
+
else:
|
196 |
+
if len(language) == 2:
|
197 |
+
# ISO 639-1 language code
|
198 |
+
acceptable_languages = list(TO_LANGUAGE_CODE.values())
|
199 |
+
elif "<" in language or "|" in language or ">" in language:
|
200 |
+
# generation config language code
|
201 |
+
acceptable_languages = list(generation_config.lang_to_id.keys())
|
202 |
+
else:
|
203 |
+
# language passed as a string
|
204 |
+
acceptable_languages = list(TO_LANGUAGE_CODE.keys())
|
205 |
+
raise ValueError(
|
206 |
+
f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
|
207 |
+
)
|
208 |
+
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
209 |
+
|
210 |
+
if task is not None:
|
211 |
+
forced_decoder_ids.append((2, generation_config.task_to_id[task]))
|
212 |
+
else:
|
213 |
+
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
|
214 |
+
|
215 |
+
if not return_timestamps:
|
216 |
+
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
|
217 |
+
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
218 |
+
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
219 |
+
else:
|
220 |
+
forced_decoder_ids.append((1, generation_config.no_timestamps_token_id))
|
221 |
+
|
222 |
+
return forced_decoder_ids
|
223 |
+
|
224 |
+
def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
|
225 |
+
inputs_len = inputs.shape[0]
|
226 |
+
step = chunk_len - stride_left - stride_right
|
227 |
+
|
228 |
+
all_chunk_start_idx = np.arange(0, inputs_len, step)
|
229 |
+
num_samples = len(all_chunk_start_idx)
|
230 |
+
|
231 |
+
num_batches = math.ceil(num_samples / batch_size)
|
232 |
+
batch_idx = np.array_split(np.arange(num_samples), num_batches)
|
233 |
+
|
234 |
+
for idx in batch_idx:
|
235 |
+
chunk_start_idx = all_chunk_start_idx[idx]
|
236 |
+
|
237 |
+
chunk_end_idx = chunk_start_idx + chunk_len
|
238 |
+
|
239 |
+
chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
|
240 |
+
processed = self.feature_extractor(
|
241 |
+
chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
|
242 |
+
)
|
243 |
+
|
244 |
+
_stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
|
245 |
+
is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
|
246 |
+
_stride_right = np.where(is_last, 0, stride_right)
|
247 |
+
|
248 |
+
chunk_lens = [chunk.shape[0] for chunk in chunks]
|
249 |
+
strides = [
|
250 |
+
(chunk_l, _stride_l, _stride_r)
|
251 |
+
for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
|
252 |
+
]
|
253 |
+
|
254 |
+
yield {"stride": strides, **processed}
|
255 |
+
|
256 |
+
def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
|
257 |
+
if isinstance(inputs, np.ndarray):
|
258 |
+
logger.warning(
|
259 |
+
"Numpy array passed as input - no sampling rate checks will be performed."
|
260 |
+
"It is strongly recommended to pass the input as a dictionary with an 'array' key "
|
261 |
+
"containing the numpy array representing the audio, and a 'sampling_rate' key "
|
262 |
+
"containing the sampling rate associated with the audio array."
|
263 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
264 |
+
)
|
265 |
+
|
266 |
+
if isinstance(inputs, str):
|
267 |
+
if inputs.startswith("http://") or inputs.startswith("https://"):
|
268 |
+
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
269 |
+
# like http_huggingface_co.png
|
270 |
+
inputs = requests.get(inputs).content
|
271 |
+
else:
|
272 |
+
with open(inputs, "rb") as f:
|
273 |
+
inputs = f.read()
|
274 |
+
|
275 |
+
if isinstance(inputs, bytes):
|
276 |
+
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
277 |
+
|
278 |
+
stride = None
|
279 |
+
if isinstance(inputs, dict):
|
280 |
+
stride = inputs.get("stride", None)
|
281 |
+
# Accepting `"array"` which is the key defined in `datasets` for
|
282 |
+
# better integration
|
283 |
+
if not ("sampling_rate" in inputs and "array" in inputs):
|
284 |
+
raise ValueError(
|
285 |
+
"When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
|
286 |
+
"containing the numpy array representing the audio, and a 'sampling_rate' key "
|
287 |
+
"containing the sampling rate associated with the audio array."
|
288 |
+
)
|
289 |
+
|
290 |
+
in_sampling_rate = inputs.get("sampling_rate")
|
291 |
+
inputs = inputs.get("array", None)
|
292 |
+
|
293 |
+
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
294 |
+
try:
|
295 |
+
import librosa
|
296 |
+
except ImportError as err:
|
297 |
+
raise ImportError(
|
298 |
+
"To support resampling audio files, please install 'librosa' and 'soundfile'."
|
299 |
+
) from err
|
300 |
+
|
301 |
+
inputs = librosa.resample(
|
302 |
+
inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
|
303 |
+
)
|
304 |
+
ratio = self.feature_extractor.sampling_rate / in_sampling_rate
|
305 |
+
else:
|
306 |
+
ratio = 1
|
307 |
+
|
308 |
+
if not isinstance(inputs, np.ndarray):
|
309 |
+
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|
310 |
+
if len(inputs.shape) != 1:
|
311 |
+
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
312 |
+
|
313 |
+
if stride is not None:
|
314 |
+
if stride[0] + stride[1] > inputs.shape[0]:
|
315 |
+
raise ValueError("Stride is too large for input")
|
316 |
+
|
317 |
+
# Stride needs to get the chunk length here, it's going to get
|
318 |
+
# swallowed by the `feature_extractor` later, and then batching
|
319 |
+
# can add extra data in the inputs, so we need to keep track
|
320 |
+
# of the original length in the stride so we can cut properly.
|
321 |
+
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
|
322 |
+
|
323 |
+
if chunk_length_s:
|
324 |
+
if stride_length_s is None:
|
325 |
+
stride_length_s = chunk_length_s / 6
|
326 |
+
|
327 |
+
if isinstance(stride_length_s, (int, float)):
|
328 |
+
stride_length_s = [stride_length_s, stride_length_s]
|
329 |
+
|
330 |
+
chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
|
331 |
+
stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
|
332 |
+
stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
|
333 |
+
|
334 |
+
if chunk_len < stride_left + stride_right:
|
335 |
+
raise ValueError("Chunk length must be superior to stride length")
|
336 |
+
|
337 |
+
for item in self.chunk_iter_with_batch(
|
338 |
+
inputs,
|
339 |
+
chunk_len,
|
340 |
+
stride_left,
|
341 |
+
stride_right,
|
342 |
+
batch_size,
|
343 |
+
):
|
344 |
+
yield item
|
345 |
+
else:
|
346 |
+
processed = self.feature_extractor(
|
347 |
+
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
|
348 |
+
)
|
349 |
+
if stride is not None:
|
350 |
+
processed["stride"] = stride
|
351 |
+
yield processed
|
352 |
+
|
353 |
+
def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
|
354 |
+
# unpack the outputs from list(dict(list)) to list(dict)
|
355 |
+
model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
|
356 |
+
|
357 |
+
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
|
358 |
+
# Send the chunking back to seconds, it's easier to handle in whisper
|
359 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
360 |
+
for output in model_outputs:
|
361 |
+
if "stride" in output:
|
362 |
+
chunk_len, stride_left, stride_right = output["stride"]
|
363 |
+
# Go back in seconds
|
364 |
+
chunk_len /= sampling_rate
|
365 |
+
stride_left /= sampling_rate
|
366 |
+
stride_right /= sampling_rate
|
367 |
+
output["stride"] = chunk_len, stride_left, stride_right
|
368 |
+
|
369 |
+
text, optional = self.tokenizer._decode_asr(
|
370 |
+
model_outputs,
|
371 |
+
return_timestamps=return_timestamps,
|
372 |
+
return_language=return_language,
|
373 |
+
time_precision=time_precision,
|
374 |
+
)
|
375 |
+
return {"text": text, **optional}
|
376 |
+
|
377 |
+
def forward(
|
378 |
+
self,
|
379 |
+
model_inputs,
|
380 |
+
batch_size=None,
|
381 |
+
language=None,
|
382 |
+
task=None,
|
383 |
+
return_timestamps=False,
|
384 |
+
num_beams=1,
|
385 |
+
length_penalty=1.0,
|
386 |
+
do_sample=False,
|
387 |
+
top_k=50,
|
388 |
+
temperature=1.0,
|
389 |
+
):
|
390 |
+
# We need to keep track of some additional input arguments for post-processing so need to forward these on after running generation
|
391 |
+
input_features = model_inputs.pop("input_features")
|
392 |
+
input_batch_size = input_features.shape[0]
|
393 |
+
|
394 |
+
if input_batch_size != batch_size:
|
395 |
+
padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
|
396 |
+
input_features = np.concatenate([input_features, padding])
|
397 |
+
|
398 |
+
pred_ids = self.generate(
|
399 |
+
input_features,
|
400 |
+
language=language,
|
401 |
+
task=task,
|
402 |
+
return_timestamps=return_timestamps,
|
403 |
+
num_beams=num_beams,
|
404 |
+
length_penalty=length_penalty,
|
405 |
+
do_sample=do_sample,
|
406 |
+
top_k=top_k,
|
407 |
+
temperature=temperature,
|
408 |
+
)[:input_batch_size]
|
409 |
+
|
410 |
+
# tokenizer's decode method expects an extra dim - we insert it here for convenience
|
411 |
+
out = {"tokens": pred_ids[:, None, :]}
|
412 |
+
|
413 |
+
stride = model_inputs.pop("stride", None)
|
414 |
+
if stride is not None:
|
415 |
+
out["stride"] = stride
|
416 |
+
|
417 |
+
return out
|
418 |
+
|
419 |
+
def __call__(
|
420 |
+
self,
|
421 |
+
inputs,
|
422 |
+
chunk_length_s=30.0,
|
423 |
+
stride_length_s=None,
|
424 |
+
batch_size=None,
|
425 |
+
language=None,
|
426 |
+
task=None,
|
427 |
+
return_timestamps=None,
|
428 |
+
num_beams=1,
|
429 |
+
length_penalty=1.0,
|
430 |
+
do_sample=False,
|
431 |
+
top_k=50,
|
432 |
+
temperature=1.0,
|
433 |
+
):
|
434 |
+
"""
|
435 |
+
Transcribe an audio input sequence to a text transcription, optionally with timestamps.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
439 |
+
The inputs is either:
|
440 |
+
- `str` that is the filename of the audio file, the file will be read at the correct sampling rate
|
441 |
+
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
|
442 |
+
- `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
|
443 |
+
same way.
|
444 |
+
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
445 |
+
Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
|
446 |
+
rate check will be done.
|
447 |
+
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
|
448 |
+
pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
|
449 |
+
np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
|
450 |
+
ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
|
451 |
+
decoding (but used at inference to provide more context to the model). In general, this additional
|
452 |
+
stride argument is not required.
|
453 |
+
chunk_length_s (`float`, *optional*, defaults to 30.0):
|
454 |
+
The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
|
455 |
+
length is set 30.0s, equal to Whisper's context window.
|
456 |
+
stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
|
457 |
+
The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
|
458 |
+
the model to *see* more context and infer letters better than without this context but the pipeline
|
459 |
+
discards the stride bits at the end to make the final reconstitution as perfect as possible.
|
460 |
+
|
461 |
+
<Tip>
|
462 |
+
|
463 |
+
For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
|
464 |
+
blog post](https://huggingface.co/blog/asr-chunking).
|
465 |
+
|
466 |
+
</Tip>
|
467 |
+
batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
|
468 |
+
The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
|
469 |
+
a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
|
470 |
+
task (`str`, *optional*):
|
471 |
+
Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
|
472 |
+
language (`str`, *optional*):
|
473 |
+
Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
|
474 |
+
Defaults to `None`, meaning the language is automatically inferred from the audio input.
|
475 |
+
return_timestamps (*optional*, `bool`):
|
476 |
+
Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline
|
477 |
+
will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"`
|
478 |
+
containing the transcription segments chunked by their utterance-level timestamps.
|
479 |
+
length_penalty (*optional*, `float`):
|
480 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an
|
481 |
+
exponent to the sequence length, which in turn is used to divide the score of the sequence. Since
|
482 |
+
the score is the log likelihood of the sequence (i.e. negative), length_penalty > 1.0 promotes
|
483 |
+
longer sequences, while length_penalty < 1.0 encourages shorter sequences.
|
484 |
+
do_sample (*optional*, `bool`):
|
485 |
+
Whether or not to use sampling ; use greedy decoding otherwise.
|
486 |
+
top_k (*optional*, `int`):
|
487 |
+
The number of the highest probability vocabulary tokens to keep for top-k-filtering.
|
488 |
+
temperature (*optional*, `float`):
|
489 |
+
The value used to modulate the next token probabilities if sampling.
|
490 |
+
|
491 |
+
Return:
|
492 |
+
`Dict`: A dictionary with the following keys:
|
493 |
+
- **text** (`str` ) -- The recognised text.
|
494 |
+
- **chunks** (*optional(, `List[Dict]`)
|
495 |
+
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
496 |
+
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
497 |
+
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
498 |
+
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
499 |
+
"""
|
500 |
+
batch_size = batch_size if batch_size is not None else self.batch_size
|
501 |
+
if batch_size % self.min_batch_size != 0:
|
502 |
+
raise ValueError(
|
503 |
+
f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
|
504 |
+
)
|
505 |
+
|
506 |
+
dataloader = self.preprocess_batch(
|
507 |
+
inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
|
508 |
+
)
|
509 |
+
model_outputs = []
|
510 |
+
# iterate over our chunked audio samples
|
511 |
+
for batch in dataloader:
|
512 |
+
model_outputs.append(
|
513 |
+
self.forward(
|
514 |
+
batch,
|
515 |
+
batch_size=batch_size,
|
516 |
+
language=language,
|
517 |
+
task=task,
|
518 |
+
return_timestamps=return_timestamps,
|
519 |
+
num_beams=num_beams,
|
520 |
+
length_penalty=length_penalty,
|
521 |
+
do_sample=do_sample,
|
522 |
+
top_k=top_k,
|
523 |
+
temperature=temperature,
|
524 |
+
)
|
525 |
+
)
|
526 |
+
post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
|
527 |
+
return post_processed
|
flax/distil_whisper/train_state.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Mapping, MutableMapping, Optional, Tuple
|
2 |
+
|
3 |
+
import flax.core
|
4 |
+
import flax.serialization
|
5 |
+
import flax.struct
|
6 |
+
import jax.numpy as jnp
|
7 |
+
from flax import traverse_util
|
8 |
+
from flax.core import scope as flax_scope
|
9 |
+
from flax.linen import partitioning as flax_partitioning
|
10 |
+
|
11 |
+
|
12 |
+
EMPTY_DICT = flax.core.freeze({})
|
13 |
+
FrozenDict = flax_scope.FrozenDict
|
14 |
+
FrozenVariableDict = flax_scope.FrozenVariableDict
|
15 |
+
MutableVariableDict = flax_scope.MutableVariableDict
|
16 |
+
VariableDict = flax_scope.VariableDict
|
17 |
+
|
18 |
+
|
19 |
+
def _validate_params_axes(params_axes, params):
|
20 |
+
axis_names = flax_partitioning.get_axis_names(params_axes)
|
21 |
+
missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
|
22 |
+
traverse_util.flatten_dict(axis_names, sep="/")
|
23 |
+
)
|
24 |
+
if missing_params_axes:
|
25 |
+
raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
|
26 |
+
|
27 |
+
|
28 |
+
def _split_variables_and_axes(
|
29 |
+
variables_and_axes: FrozenVariableDict,
|
30 |
+
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
|
31 |
+
"""Splits `variables_and_axes` into two separate dicts with the same keys."""
|
32 |
+
# For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`.
|
33 |
+
variables = {}
|
34 |
+
axes = {}
|
35 |
+
for k, v in variables_and_axes.items():
|
36 |
+
if k.endswith("_axes"):
|
37 |
+
axes[k[:-5]] = v # k without "_axes".
|
38 |
+
_validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes".
|
39 |
+
else:
|
40 |
+
variables[k] = v
|
41 |
+
return flax.core.freeze(variables), flax.core.freeze(axes)
|
42 |
+
|
43 |
+
|
44 |
+
class InferenceState(flax.struct.PyTreeNode):
|
45 |
+
"""State compatible with FlaxOptimTrainState without optimizer state."""
|
46 |
+
|
47 |
+
step: jnp.ndarray
|
48 |
+
params: flax_scope.FrozenVariableDict
|
49 |
+
params_axes: Optional[flax_scope.FrozenVariableDict] = None
|
50 |
+
flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
|
51 |
+
flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
|
55 |
+
other_variables, params = model_variables.pop("params")
|
56 |
+
if "params_axes" in other_variables:
|
57 |
+
other_variables, params_axes = other_variables.pop("params_axes")
|
58 |
+
_validate_params_axes(params_axes, params)
|
59 |
+
else:
|
60 |
+
params_axes = None
|
61 |
+
|
62 |
+
# Split other_variables into mutables and their corresponding axes.
|
63 |
+
flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
|
64 |
+
flax_mutables_axes = flax_mutables_axes or None
|
65 |
+
return InferenceState(
|
66 |
+
step=jnp.array(0),
|
67 |
+
params=params,
|
68 |
+
params_axes=params_axes,
|
69 |
+
flax_mutables=flax_mutables,
|
70 |
+
flax_mutables_axes=flax_mutables_axes,
|
71 |
+
)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def param_states(self) -> FrozenVariableDict:
|
75 |
+
"""The optimizer states of the parameters as a PyTree."""
|
76 |
+
raise NotImplementedError("InferenceState has no optimizer states.")
|
77 |
+
|
78 |
+
def apply_gradient(self, *args, **kwargs) -> "InferenceState":
|
79 |
+
raise NotImplementedError("InferenceState does not support `apply_gradient`.")
|
80 |
+
|
81 |
+
def state_dict(self) -> MutableMapping[str, Any]:
|
82 |
+
state_dict = {
|
83 |
+
"target": flax.core.unfreeze(self.params),
|
84 |
+
"state": {"step": self.step},
|
85 |
+
}
|
86 |
+
if self.flax_mutables:
|
87 |
+
state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
|
88 |
+
return state_dict
|
89 |
+
|
90 |
+
def replace_step(self, step: jnp.ndarray) -> "InferenceState":
|
91 |
+
return self.replace(step=step)
|
92 |
+
|
93 |
+
def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
|
94 |
+
return self.replace(params=params)
|
95 |
+
|
96 |
+
def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
|
97 |
+
return self.replace(flax_mutables=flax_mutables)
|
98 |
+
|
99 |
+
def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
|
100 |
+
return self.replace(
|
101 |
+
params=flax.core.freeze(state_dict["target"]),
|
102 |
+
step=state_dict["state"]["step"],
|
103 |
+
flax_mutables=(
|
104 |
+
flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT
|
105 |
+
),
|
106 |
+
)
|
107 |
+
|
108 |
+
def as_logical_axes(self) -> "InferenceState":
|
109 |
+
# Set step to None so that when the logical axes are processed by the
|
110 |
+
# flax.partitioning.logical_to_mesh_axes function, it will be skipped
|
111 |
+
# because jax.tree_map will short circut and never call the function on the
|
112 |
+
# step.
|
113 |
+
flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
|
114 |
+
return InferenceState(
|
115 |
+
step=None,
|
116 |
+
params=flax_partitioning.get_axis_names(self.params_axes),
|
117 |
+
flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
|
118 |
+
)
|
flax/distillation_scripts/run_32_2_pt.sh
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 run_distillation_pt.py \
|
4 |
+
--model_name_or_path distil-whisper/large-32-2 \
|
5 |
+
--teacher_model_name_or_path openai/whisper-large-v2 \
|
6 |
+
--train_dataset_config_name all+all+all+l \
|
7 |
+
--train_dataset_samples 2.9+10.4+14.9+226.6 \
|
8 |
+
--train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l \
|
9 |
+
--train_split_name train.clean.100+train.clean.360+train.other.500+train \
|
10 |
+
--eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l \
|
11 |
+
--eval_dataset_config_name all+all+l \
|
12 |
+
--eval_split_name validation.clean+validation.other+validation \
|
13 |
+
--eval_text_column_name text+text+text \
|
14 |
+
--eval_steps 2500 \
|
15 |
+
--save_steps 2500 \
|
16 |
+
--warmup_steps 50 \
|
17 |
+
--learning_rate 0.0001 \
|
18 |
+
--lr_scheduler_type constant_with_warmup \
|
19 |
+
--logging_steps 25 \
|
20 |
+
--save_total_limit 1 \
|
21 |
+
--max_steps 10000 \
|
22 |
+
--wer_threshold 10 \
|
23 |
+
--per_device_train_batch_size 64 \
|
24 |
+
--gradient_accumulation_steps 2 \
|
25 |
+
--per_device_eval_batch_size 64 \
|
26 |
+
--dataloader_num_workers 16 \
|
27 |
+
--cache_dir /fsx/sanchit/cache \
|
28 |
+
--dataset_cache_dir /fsx/sanchit/cache \
|
29 |
+
--dtype bfloat16 \
|
30 |
+
--output_dir ./ \
|
31 |
+
--wandb_project distil-whisper-training \
|
32 |
+
--do_train \
|
33 |
+
--do_eval \
|
34 |
+
--gradient_checkpointing \
|
35 |
+
--overwrite_output_dir \
|
36 |
+
--predict_with_generate \
|
37 |
+
--freeze_encoder \
|
38 |
+
--streaming
|
flax/distillation_scripts/run_bs_sweep.yaml
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
command:
|
2 |
+
- python3
|
3 |
+
- ${program}
|
4 |
+
- --do_train
|
5 |
+
- --use_scan
|
6 |
+
- --gradient_checkpointing
|
7 |
+
- --overwrite_output_dir
|
8 |
+
- --predict_with_generate
|
9 |
+
- --freeze_encoder
|
10 |
+
- --streaming
|
11 |
+
- --use_auth_token
|
12 |
+
- --compilation_cache
|
13 |
+
- ${args}
|
14 |
+
method: grid
|
15 |
+
metric:
|
16 |
+
goal: minimize
|
17 |
+
name: train/loss
|
18 |
+
parameters:
|
19 |
+
model_name_or_path:
|
20 |
+
value: distil-whisper/large-32-2
|
21 |
+
teacher_model_name_or_path:
|
22 |
+
value: openai/whisper-large-v2
|
23 |
+
train_dataset_name:
|
24 |
+
value: librispeech_asr
|
25 |
+
train_dataset_config_name:
|
26 |
+
value: all
|
27 |
+
train_split_name:
|
28 |
+
value: train.other.500
|
29 |
+
train_dataset_samples:
|
30 |
+
value: 100
|
31 |
+
cache_dir:
|
32 |
+
value: /fsx/sanchitgandhi/cache
|
33 |
+
dataset_cache_dir:
|
34 |
+
value: /fsx/sanchitgandhi/cache
|
35 |
+
output_dir:
|
36 |
+
value: ./
|
37 |
+
per_device_train_batch_size:
|
38 |
+
values:
|
39 |
+
- 128
|
40 |
+
- 256
|
41 |
+
- 512
|
42 |
+
precision:
|
43 |
+
values:
|
44 |
+
- "full_mixed"
|
45 |
+
- "half_mixed"
|
46 |
+
dtype:
|
47 |
+
value: bfloat16
|
48 |
+
do_eval:
|
49 |
+
value: false
|
50 |
+
learning_rate:
|
51 |
+
value: 3e-4
|
52 |
+
lr_scheduler_type:
|
53 |
+
value: constant_with_warmup
|
54 |
+
warmup_steps:
|
55 |
+
value: 30
|
56 |
+
max_steps:
|
57 |
+
value: 30
|
58 |
+
save_steps:
|
59 |
+
value: 51 # don't save checkpoints during sweep
|
60 |
+
dataloader_num_workers:
|
61 |
+
value: 48
|
62 |
+
logging_steps:
|
63 |
+
value: 5
|
64 |
+
wer_threshold:
|
65 |
+
value: 100
|
66 |
+
program: run_distillation.py
|
67 |
+
project: distil-whisper-sweeps
|
flax/distillation_scripts/run_dataset_sweep.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
command:
|
2 |
+
- python3
|
3 |
+
- ${program}
|
4 |
+
- --do_train
|
5 |
+
- --do_eval
|
6 |
+
- --use_scan
|
7 |
+
- --gradient_checkpointing
|
8 |
+
- --overwrite_output_dir
|
9 |
+
- --predict_with_generate
|
10 |
+
- --freeze_encoder
|
11 |
+
- --streaming
|
12 |
+
- --use_auth_token
|
13 |
+
- ${args}
|
14 |
+
method: grid
|
15 |
+
metric:
|
16 |
+
goal: minimize
|
17 |
+
name: gigaspeech-l/validation/wer
|
18 |
+
parameters:
|
19 |
+
model_name_or_path:
|
20 |
+
value: distil-whisper/large-32-2
|
21 |
+
teacher_model_name_or_path:
|
22 |
+
value: openai/whisper-large-v2
|
23 |
+
max_train_samples:
|
24 |
+
values:
|
25 |
+
- 109876
|
26 |
+
- 219752
|
27 |
+
- 439504
|
28 |
+
- 879008
|
29 |
+
- 1758015
|
30 |
+
- 3516030
|
31 |
+
- 7032061
|
32 |
+
train_dataset_name:
|
33 |
+
value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted
|
34 |
+
train_dataset_config_name:
|
35 |
+
value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3
|
36 |
+
train_split_name:
|
37 |
+
value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train
|
38 |
+
train_dataset_samples:
|
39 |
+
value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8
|
40 |
+
eval_dataset_name:
|
41 |
+
value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs
|
42 |
+
eval_dataset_config_name:
|
43 |
+
value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us
|
44 |
+
eval_split_name:
|
45 |
+
value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation
|
46 |
+
eval_text_column_name:
|
47 |
+
value: text+text+text+text+text+text+text+text+text+text+text+text+transcription
|
48 |
+
cache_dir:
|
49 |
+
value: /home/sanchitgandhi/.cache
|
50 |
+
dataset_cache_dir:
|
51 |
+
value: /home/sanchitgandhi/.cache
|
52 |
+
output_dir:
|
53 |
+
value: ./
|
54 |
+
per_device_train_batch_size:
|
55 |
+
value: 64
|
56 |
+
per_device_eval_batch_size:
|
57 |
+
value: 64
|
58 |
+
dtype:
|
59 |
+
value: bfloat16
|
60 |
+
learning_rate:
|
61 |
+
value: 1e-4
|
62 |
+
lr_scheduler_type:
|
63 |
+
value: constant_with_warmup
|
64 |
+
warmup_steps:
|
65 |
+
value: 50
|
66 |
+
max_steps:
|
67 |
+
value: 10000
|
68 |
+
save_steps:
|
69 |
+
value: 10001 # don't save checkpoints during sweep
|
70 |
+
dataloader_num_workers:
|
71 |
+
value: 48
|
72 |
+
logging_steps:
|
73 |
+
value: 25
|
74 |
+
wer_threshold:
|
75 |
+
value: 10
|
76 |
+
program: run_distillation.py
|
77 |
+
project: distil-whisper-sweeps
|
flax/distillation_scripts/run_decoder_sweep.yaml
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
command:
|
2 |
+
- python3
|
3 |
+
- ${program}
|
4 |
+
- --do_train
|
5 |
+
- --do_eval
|
6 |
+
- --use_scan
|
7 |
+
- --gradient_checkpointing
|
8 |
+
- --overwrite_output_dir
|
9 |
+
- --predict_with_generate
|
10 |
+
- --freeze_encoder
|
11 |
+
- --streaming
|
12 |
+
- --use_auth_token
|
13 |
+
- ${args}
|
14 |
+
method: grid
|
15 |
+
metric:
|
16 |
+
goal: minimize
|
17 |
+
name: gigaspeech-l/validation/wer
|
18 |
+
parameters:
|
19 |
+
model_name_or_path:
|
20 |
+
values:
|
21 |
+
- distil-whisper/large-32-16
|
22 |
+
- distil-whisper/large-32-8
|
23 |
+
- distil-whisper/large-32-4
|
24 |
+
- distil-whisper/large-32-2
|
25 |
+
teacher_model_name_or_path:
|
26 |
+
value: openai/whisper-large-v2
|
27 |
+
train_dataset_name:
|
28 |
+
value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted
|
29 |
+
train_dataset_config_name:
|
30 |
+
value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3
|
31 |
+
train_split_name:
|
32 |
+
value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train
|
33 |
+
train_dataset_samples:
|
34 |
+
value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8
|
35 |
+
eval_dataset_name:
|
36 |
+
value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs
|
37 |
+
eval_dataset_config_name:
|
38 |
+
value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us
|
39 |
+
eval_split_name:
|
40 |
+
value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation
|
41 |
+
eval_text_column_name:
|
42 |
+
value: text+text+text+text+text+text+text+text+text+text+text+text+transcription
|
43 |
+
cache_dir:
|
44 |
+
value: /home/sanchitgandhi/.cache
|
45 |
+
dataset_cache_dir:
|
46 |
+
value: /home/sanchitgandhi/.cache
|
47 |
+
output_dir:
|
48 |
+
value: ./
|
49 |
+
per_device_train_batch_size:
|
50 |
+
value: 64
|
51 |
+
per_device_eval_batch_size:
|
52 |
+
value: 64
|
53 |
+
dtype:
|
54 |
+
value: bfloat16
|
55 |
+
learning_rate:
|
56 |
+
value: 1e-4
|
57 |
+
lr_scheduler_type:
|
58 |
+
value: constant_with_warmup
|
59 |
+
warmup_steps:
|
60 |
+
value: 50
|
61 |
+
max_steps:
|
62 |
+
value: 10000
|
63 |
+
save_steps:
|
64 |
+
value: 10001 # don't save checkpoints during sweep
|
65 |
+
dataloader_num_workers:
|
66 |
+
value: 48
|
67 |
+
logging_steps:
|
68 |
+
value: 25
|
69 |
+
wer_threshold:
|
70 |
+
value: 10
|
71 |
+
program: run_distillation.py
|
72 |
+
project: distil-whisper-sweeps
|
flax/distillation_scripts/run_distillation_12_2_timestamped.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
|
4 |
+
--model_name_or_path "distil-whisper/small-12-2" \
|
5 |
+
--teacher_model_name_or_path "openai/whisper-medium.en" \
|
6 |
+
--train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \
|
7 |
+
--train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \
|
8 |
+
--train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \
|
9 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \
|
10 |
+
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
|
11 |
+
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
|
12 |
+
--eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
|
13 |
+
--eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
|
14 |
+
--eval_steps 5000 \
|
15 |
+
--save_steps 5000 \
|
16 |
+
--warmup_steps 500 \
|
17 |
+
--learning_rate 0.0001 \
|
18 |
+
--logging_steps 25 \
|
19 |
+
--save_total_limit 1 \
|
20 |
+
--max_steps 80000 \
|
21 |
+
--wer_threshold 10 \
|
22 |
+
--per_device_train_batch_size 64 \
|
23 |
+
--per_device_eval_batch_size 64 \
|
24 |
+
--dtype "bfloat16" \
|
25 |
+
--dataloader_num_workers 16 \
|
26 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
27 |
+
--dataset_cache_dir "/home/sanchitgandhi/.cache" \
|
28 |
+
--output_dir "./" \
|
29 |
+
--timestamp_probability 0.2 \
|
30 |
+
--wandb_name "small-12-2-tpu-timestamped-prob-0.2" \
|
31 |
+
--wandb_dir "/home/sanchitgandhi/.cache" \
|
32 |
+
--wandb_project "distil-whisper" \
|
33 |
+
--do_train \
|
34 |
+
--do_eval \
|
35 |
+
--use_scan \
|
36 |
+
--gradient_checkpointing \
|
37 |
+
--overwrite_output_dir \
|
38 |
+
--predict_with_generate \
|
39 |
+
--freeze_encoder \
|
40 |
+
--streaming \
|
41 |
+
--use_auth_token \
|
42 |
+
--push_to_hub
|
flax/distillation_scripts/run_distillation_15s_context.sh
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
|
4 |
+
--model_name_or_path "distil-whisper/large-32-2-15s-context" \
|
5 |
+
--teacher_model_name_or_path "openai/whisper-large-v2" \
|
6 |
+
--feature_extractor_name "openai/whisper-large-v2" \
|
7 |
+
--train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \
|
8 |
+
--train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \
|
9 |
+
--train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \
|
10 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \
|
11 |
+
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
|
12 |
+
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
|
13 |
+
--eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
|
14 |
+
--eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
|
15 |
+
--eval_steps 5000 \
|
16 |
+
--save_steps 5000 \
|
17 |
+
--warmup_steps 500 \
|
18 |
+
--learning_rate 0.0001 \
|
19 |
+
--lr_scheduler_type "linear" \
|
20 |
+
--logging_steps 25 \
|
21 |
+
--save_total_limit 1 \
|
22 |
+
--max_steps 80000 \
|
23 |
+
--wer_threshold 10 \
|
24 |
+
--per_device_train_batch_size 64 \
|
25 |
+
--per_device_eval_batch_size 64 \
|
26 |
+
--max_duration_in_seconds 15 \
|
27 |
+
--dataloader_num_workers 16 \
|
28 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
29 |
+
--dataset_cache_dir "/home/sanchitgandhi/.cache" \
|
30 |
+
--dtype "bfloat16" \
|
31 |
+
--output_dir "./" \
|
32 |
+
--wandb_name "large-32-2-ts-28k-wer-10-context-15s" \
|
33 |
+
--wandb_dir "/home/sanchitgandhi/.cache" \
|
34 |
+
--wandb_project "distil-whisper" \
|
35 |
+
--do_train \
|
36 |
+
--do_eval \
|
37 |
+
--use_scan \
|
38 |
+
--gradient_checkpointing \
|
39 |
+
--overwrite_output_dir \
|
40 |
+
--predict_with_generate \
|
41 |
+
--streaming \
|
42 |
+
--use_auth_token \
|
43 |
+
--push_to_hub
|
flax/distillation_scripts/run_distillation_16_2.sh
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
|
4 |
+
--model_name_or_path "distil-whisper/large-16-2" \
|
5 |
+
--teacher_model_name_or_path "openai/whisper-large-v2" \
|
6 |
+
--train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \
|
7 |
+
--train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \
|
8 |
+
--train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \
|
9 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \
|
10 |
+
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
|
11 |
+
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
|
12 |
+
--eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
|
13 |
+
--eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
|
14 |
+
--eval_steps 5000 \
|
15 |
+
--save_steps 5000 \
|
16 |
+
--warmup_steps 500 \
|
17 |
+
--learning_rate 0.0001 \
|
18 |
+
--lr_scheduler_type "linear" \
|
19 |
+
--logging_steps 25 \
|
20 |
+
--save_total_limit 1 \
|
21 |
+
--max_steps 80000 \
|
22 |
+
--wer_threshold 10 \
|
23 |
+
--per_device_eval_batch_size 64 \
|
24 |
+
--per_device_train_batch_size 64 \
|
25 |
+
--dataloader_num_workers 16 \
|
26 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
27 |
+
--dataset_cache_dir "/home/sanchitgandhi/.cache" \
|
28 |
+
--dtype "bfloat16" \
|
29 |
+
--output_dir "./" \
|
30 |
+
--wandb_name "large-16-2-ts-28k-wer-10" \
|
31 |
+
--wandb_dir "/home/sanchitgandhi/.cache" \
|
32 |
+
--wandb_project "distil-whisper" \
|
33 |
+
--do_train \
|
34 |
+
--do_eval \
|
35 |
+
--use_scan \
|
36 |
+
--gradient_checkpointing \
|
37 |
+
--overwrite_output_dir \
|
38 |
+
--predict_with_generate \
|
39 |
+
--streaming \
|
40 |
+
--use_auth_token \
|
41 |
+
--push_to_hub
|
flax/distillation_scripts/run_distillation_24_2.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
|
4 |
+
--model_name_or_path "distil-whisper/medium-24-2" \
|
5 |
+
--teacher_model_name_or_path "openai/whisper-medium.en" \
|
6 |
+
--train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \
|
7 |
+
--train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \
|
8 |
+
--train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \
|
9 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \
|
10 |
+
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
|
11 |
+
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
|
12 |
+
--eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
|
13 |
+
--eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
|
14 |
+
--eval_steps 5000 \
|
15 |
+
--save_steps 5000 \
|
16 |
+
--warmup_steps 500 \
|
17 |
+
--learning_rate 0.0001 \
|
18 |
+
--lr_scheduler_type "linear" \
|
19 |
+
--logging_steps 25 \
|
20 |
+
--save_total_limit 1 \
|
21 |
+
--max_steps 80000 \
|
22 |
+
--wer_threshold 10 \
|
23 |
+
--per_device_eval_batch_size 64 \
|
24 |
+
--per_device_train_batch_size 64 \
|
25 |
+
--dataloader_num_workers 16 \
|
26 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
27 |
+
--dataset_cache_dir "/home/sanchitgandhi/.cache" \
|
28 |
+
--dtype "bfloat16" \
|
29 |
+
--output_dir "./" \
|
30 |
+
--wandb_name "medium-24-2-ts-freeze-28k-wer-10" \
|
31 |
+
--wandb_dir "/home/sanchitgandhi/.cache" \
|
32 |
+
--wandb_project "distil-whisper" \
|
33 |
+
--do_train \
|
34 |
+
--do_eval \
|
35 |
+
--use_scan \
|
36 |
+
--gradient_checkpointing \
|
37 |
+
--overwrite_output_dir \
|
38 |
+
--predict_with_generate \
|
39 |
+
--streaming \
|
40 |
+
--freeze_encoder \
|
41 |
+
--use_auth_token \
|
42 |
+
--push_to_hub
|
flax/distillation_scripts/run_distillation_24_2_timestamped.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
|
4 |
+
--model_name_or_path "distil-whisper/medium-24-2" \
|
5 |
+
--teacher_model_name_or_path "openai/whisper-medium.en" \
|
6 |
+
--train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \
|
7 |
+
--train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \
|
8 |
+
--train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \
|
9 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \
|
10 |
+
--eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
|
11 |
+
--eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
|
12 |
+
--eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
|
13 |
+
--eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
|
14 |
+
--eval_steps 5000 \
|
15 |
+
--save_steps 5000 \
|
16 |
+
--warmup_steps 500 \
|
17 |
+
--learning_rate 0.0001 \
|
18 |
+
--logging_steps 25 \
|
19 |
+
--save_total_limit 1 \
|
20 |
+
--max_steps 80000 \
|
21 |
+
--wer_threshold 10 \
|
22 |
+
--per_device_train_batch_size 64 \
|
23 |
+
--per_device_eval_batch_size 64 \
|
24 |
+
--dtype "bfloat16" \
|
25 |
+
--dataloader_num_workers 16 \
|
26 |
+
--cache_dir "/home/sanchitgandhi/.cache" \
|
27 |
+
--dataset_cache_dir "/home/sanchitgandhi/.cache" \
|
28 |
+
--output_dir "./" \
|
29 |
+
--timestamp_probability 0.2 \
|
30 |
+
--wandb_name "medium-24-2-tpu-timestamped-prob-0.2" \
|
31 |
+
--wandb_dir "/home/sanchitgandhi/.cache" \
|
32 |
+
--wandb_project "distil-whisper" \
|
33 |
+
--do_train \
|
34 |
+
--do_eval \
|
35 |
+
--use_scan \
|
36 |
+
--gradient_checkpointing \
|
37 |
+
--overwrite_output_dir \
|
38 |
+
--predict_with_generate \
|
39 |
+
--freeze_encoder \
|
40 |
+
--streaming \
|
41 |
+
--use_auth_token \
|
42 |
+
--push_to_hub
|