Lakoc committed on
Commit 702de8f · verified · 1 Parent(s): 35c11e8

Upload DiCoWForConditionalGeneration

Files changed (10)
  1. README.md +199 -0
  2. config.json +78 -0
  3. config.py +85 -0
  4. decoding.py +397 -0
  5. encoder.py +364 -0
  6. generation.py +1770 -0
  7. generation_config.json +12 -0
  8. model.safetensors +3 -0
  9. modeling_dicow.py +362 -0
  10. utils.py +96 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for BUT-FIT/DiCoW_v3_MLC
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
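
Until the authors fill this in, here is a minimal loading sketch based on the files in this commit (the `auto_map` in `config.json` routes `AutoModelForSpeechSeq2Seq` to `modeling_dicow.DiCoWForConditionalGeneration`). The processor choice and the shape and semantics of `stno_mask` are assumptions inferred from `encoder.py` and `generation.py`, not an official recipe:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# trust_remote_code pulls in the custom config.py / modeling_dicow.py shipped in this repo.
model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_MLC", trust_remote_code=True)

# Assumption: a Whisper large-v3 processor matches this checkpoint
# (config.json: num_mel_bins=128, vocab_size=51866).
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")

# Replace with real 16 kHz audio; 30 s of zeros is only a placeholder.
waveform = torch.zeros(16000 * 30)
inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

# DiCoW additionally expects a per-frame STNO mask (silence / target / non-target / overlap)
# over the 1500 encoder frames; here every frame is marked as target speech for illustration.
stno_mask = torch.zeros(1, 4, 1500)
stno_mask[:, 1, :] = 1.0

pred_ids = model.generate(inputs.input_features, stno_mask=stno_mask)
print(processor.batch_decode(pred_ids, skip_special_tokens=True))
```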
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,78 @@
1
+ {
2
+ "_name_or_path": "BUT-FIT/DiCoW_v3_MLC",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "additional_layer": false,
6
+ "additional_self_attention_layer": true,
7
+ "apply_fddt_to_n_layers": -1,
8
+ "apply_spec_augment": false,
9
+ "architectures": [
10
+ "DiCoWForConditionalGeneration"
11
+ ],
12
+ "attention_dropout": 0.0,
13
+ "auto_map": {
14
+ "AutoConfig": "config.DiCoWConfig",
15
+ "AutoModelForSpeechSeq2Seq": "modeling_dicow.DiCoWForConditionalGeneration"
16
+ },
17
+ "begin_suppress_tokens": [
18
+ 220,
19
+ 50256
20
+ ],
21
+ "blank_token_id": null,
22
+ "bos_token_id": 50257,
23
+ "classifier_proj_size": 256,
24
+ "ctc_loss_reduction": "mean",
25
+ "ctc_weight": 0.3,
26
+ "ctc_zero_infinity": false,
27
+ "d_model": 1280,
28
+ "decoder_attention_heads": 20,
29
+ "decoder_ffn_dim": 5120,
30
+ "decoder_layerdrop": 0.0,
31
+ "decoder_layers": 4,
32
+ "decoder_start_token_id": 50258,
33
+ "dropout": 0.0,
34
+ "encoder_attention_heads": 20,
35
+ "encoder_ffn_dim": 5120,
36
+ "encoder_layerdrop": 0.0,
37
+ "encoder_layers": 32,
38
+ "eos_token_id": 50257,
39
+ "fddt_bias_only": false,
40
+ "fddt_init": "disparagement",
41
+ "fddt_is_diagonal": true,
42
+ "fddt_use_non_target": true,
43
+ "fddt_use_overlap": true,
44
+ "fddt_use_silence": true,
45
+ "fddt_use_target": true,
46
+ "final_dropout": 0.0,
47
+ "forced_decoder_ids": null,
48
+ "init_std": 0.02,
49
+ "is_encoder_decoder": true,
50
+ "mask_feature_length": 10,
51
+ "mask_feature_min_masks": 0,
52
+ "mask_feature_prob": 0.0,
53
+ "mask_time_length": 10,
54
+ "mask_time_min_masks": 2,
55
+ "mask_time_prob": 0.05,
56
+ "max_source_positions": 1500,
57
+ "max_target_positions": 448,
58
+ "median_filter_width": 7,
59
+ "model_type": "whisper",
60
+ "mt_num_speakers": 1,
61
+ "n_soft_prompts": 16,
62
+ "non_target_fddt_value": 0.5,
63
+ "num_hidden_layers": 32,
64
+ "num_mel_bins": 128,
65
+ "pad_token_id": 50257,
66
+ "remove_timestamps_from_ctc": false,
67
+ "scale_embedding": false,
68
+ "scb_layers": -1,
69
+ "scb_method": null,
70
+ "sub_sample": true,
71
+ "torch_dtype": "float32",
72
+ "transformers_version": "4.42.0",
73
+ "use_cache": true,
74
+ "use_fddt": true,
75
+ "use_initial_fddt": true,
76
+ "use_weighted_layer_sum": false,
77
+ "vocab_size": 51866
78
+ }
config.py ADDED
@@ -0,0 +1,85 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers import WhisperConfig
6
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput
7
+
8
+
9
+ @dataclass
10
+ class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
11
+ enc_loss: Optional[torch.FloatTensor] = None
12
+ dec_loss: Optional[torch.FloatTensor] = None
13
+ encoder_logits: Optional[torch.FloatTensor] = None
14
+
15
+
16
+ @dataclass
17
+ class BaseModelOutputLogit(BaseModelOutput):
18
+ logits: Optional[torch.FloatTensor] = None
19
+
20
+
21
+ @dataclass
22
+ class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
23
+ encoder_logits: Optional[torch.FloatTensor] = None
24
+
25
+
26
+ class DiCoWConfig(WhisperConfig):
27
+ """This is a modified version of the `WhisperEncoder` model from the `transformers` library.
28
+ The underlying Whisper encoder is modified to support CTC loss computation and FDDT conditioning in its forward pass."""
29
+
30
+ def __init__(
31
+ self,
32
+ ctc_loss_reduction: str = "mean",
33
+ final_dropout: float = 0.0,
34
+ ctc_zero_infinity: bool = False,
35
+ ctc_weight: float = 0.0,
36
+ blank_token_id: Optional[int] = None,
37
+ additional_layer: bool = False,
38
+ additional_self_attention_layer: bool = False,
39
+ sub_sample: bool = False,
40
+ use_fddt: bool = True,
41
+ fddt_is_diagonal: bool = True,
42
+ fddt_bias_only: bool = False,
43
+ fddt_use_silence: bool = True,
44
+ fddt_use_target: bool = True,
45
+ fddt_use_overlap: bool = True,
46
+ fddt_use_non_target: bool = True,
47
+ remove_timestamps_from_ctc: bool = False,
48
+ apply_fddt_to_n_layers: int = -1,
49
+ fddt_init: str = 'non-disturbing',  # random, non-disturbing, disparagement
50
+ n_soft_prompts: int = 16,
51
+ mt_num_speakers: int = 1,
52
+ non_target_fddt_value: float = 0.0,
53
+ use_initial_fddt: bool = False,
54
+ scb_method: str = None,
55
+ scb_layers: int = -1,
56
+ **kwargs,
57
+ ):
58
+ super().__init__(**kwargs)
59
+ self.ctc_loss_reduction = ctc_loss_reduction
60
+ self.final_dropout = final_dropout
61
+ self.ctc_zero_infinity = ctc_zero_infinity
62
+ self.ctc_weight = ctc_weight
63
+ self.blank_token_id = blank_token_id
64
+ self.additional_layer = additional_layer
65
+ self.additional_self_attention_layer = additional_self_attention_layer
66
+ self.sub_sample = sub_sample
67
+ self.use_fddt = use_fddt
68
+ self.fddt_is_diagonal = fddt_is_diagonal
69
+ self.fddt_bias_only = fddt_bias_only
70
+ self.fddt_use_silence = fddt_use_silence
71
+ self.fddt_use_target = fddt_use_target
72
+ self.fddt_use_overlap = fddt_use_overlap
73
+ self.fddt_use_non_target = fddt_use_non_target
74
+ self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
75
+ self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
76
+ self.fddt_init = fddt_init
77
+ self.n_soft_prompts = n_soft_prompts
78
+ self.mt_num_speakers = mt_num_speakers
79
+ self.non_target_fddt_value = non_target_fddt_value
80
+ self.use_initial_fddt = use_initial_fddt
81
+ self.scb_method = scb_method
82
+ self.scb_layers = scb_layers
83
+
84
+
85
+ _HIDDEN_STATES_START_POSITION = 2
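
Not part of the committed file, just a quick sketch of how these fields surface once the repository is loaded with `trust_remote_code` (values mirror the `config.json` above; the `AutoConfig` route follows its `auto_map` entry):

```python
from transformers import AutoConfig

# auto_map in config.json resolves AutoConfig to config.DiCoWConfig.
config = AutoConfig.from_pretrained("BUT-FIT/DiCoW_v3_MLC", trust_remote_code=True)
print(config.ctc_weight, config.use_fddt, config.fddt_init)  # 0.3 True disparagement
print(config.mt_num_speakers, config.n_soft_prompts)         # 1 16
```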
decoding.py ADDED
@@ -0,0 +1,397 @@
1
+ # pylint: skip-file
2
+ # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
+ import itertools as it
4
+ from typing import List
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import LogitsProcessor, PreTrainedTokenizer
9
+
10
+
11
+ class CTCPrefixScore(object):
12
+ """Compute CTC label sequence scores
13
+
14
+ which is based on Algorithm 2 in WATANABE et al.
15
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
16
+ but extended to efficiently compute the label probabilities for multiple
17
+ hypotheses simultaneously
18
+ See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
19
+ Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
20
+ """
21
+
22
+ def __init__(self, x, blank, eos):
23
+ self.logzero = -1e10
24
+ self.blank = blank
25
+ self.eos = eos
26
+ self.input_length = x.shape[1]
27
+ self.batch_size = x.shape[0]
28
+ self.x = x
29
+ self.device = x.device
30
+
31
+ # Preallocate `r` and `xs` tensors
32
+ # `num_labels` will be set dynamically in __call__ but preallocated with maximum capacity
33
+ self.max_num_labels = x.shape[2] # Set to a max value that can be dynamically resized
34
+ self.r = torch.full((self.batch_size, self.input_length, 2, self.max_num_labels), self.logzero,
35
+ device=self.device)
36
+ self.xs = torch.full((self.batch_size, self.input_length, self.max_num_labels), self.logzero,
37
+ device=self.device)
38
+
39
+ def initial_state(self):
40
+ """Obtain an initial CTC state."""
41
+ # Create initial CTC state tensor and use in-place operations to fill
42
+ r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device)
43
+ r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1)
44
+ s = torch.zeros((self.batch_size, 1), device=self.device)
45
+
46
+ return r, s
47
+
48
+ def _resize_tensors(self, number_of_current_samples, num_labels):
49
+ if self.r.shape[0] != number_of_current_samples:
50
+ self.r = self.r[:number_of_current_samples, ...]
51
+ self.xs = self.xs[:number_of_current_samples, ...]
52
+
53
+ if self.r.shape[3] != num_labels:
54
+ self.r = self.r[:, :, :, :num_labels].fill_(self.logzero)
55
+ self.xs = self.xs[:, :, :num_labels].fill_(self.logzero)
56
+ else:
57
+ self.r.fill_(self.logzero)
58
+ self.xs.fill_(self.logzero)
59
+
60
+ def _initialize_r(self, decoded_len):
61
+ mask = (decoded_len == 0)
62
+ self.r[mask, 0, 0, :] = self.xs[mask, 0]
63
+
64
+ def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev):
65
+ # Expand r_sum for num_labels and initialize log_phi
66
+ log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1])
67
+
68
+ # Create mask for cases where `decoded_len > 0` and to identify where `c == last[i]` for all `i`
69
+ non_zero_mask = (decoded_len > 0)
70
+ label_match_mask = (cs == last.unsqueeze(1))
71
+
72
+ # Update log_phi where both `decoded_len > 0` and `c == last[i]`
73
+ log_phi = torch.where((non_zero_mask.unsqueeze(1) & label_match_mask)[:, None, :], r_prev[..., 1:2], log_phi)
74
+ return log_phi
75
+
76
+ def _compute_log_psi(self, decoded_len, log_phi, x_current):
77
+ """This function computes forward probabilities log(r_t^n(h)), log(r_t^b(h)),
78
+ and log prefix probabilities log(psi) for all labels in the batch.
79
+
80
+ :param decoded_len: tensor of shape (batch_size,) containing the length of the decoded sequence
81
+ :param log_phi: tensor of shape (batch_size, input_length, num_labels) containing the forward probabilities
82
+ :param x_current: tensor of shape (batch_size, input_length, num_labels) containing the input frame
83
+
84
+ :return log_psi: tensor of shape (batch_size,num_labels) containing the log prefix probabilities
85
+ """
86
+ B, T, V = log_phi.shape
87
+ start = torch.clamp(decoded_len, min=1) # Ensure start is at least 1 to avoid out-of-bounds
88
+
89
+ # Initialize log_psi with the start position of r[:, start - 1, 0, :]
90
+ log_psi = self.r[torch.arange(B), start - 1, 0, :]
91
+
92
+ # Mask for handling sequence lengths based on decoded_len
93
+ mask_t = torch.arange(1, T, device=decoded_len.device).expand(B, T - 1) >= decoded_len.unsqueeze(1)
94
+
95
+ # Accumulate log_psi only up to the last valid time step for each sequence
96
+ log_psi = torch.logaddexp(log_psi, torch.logsumexp(
97
+ torch.where(mask_t.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero), dim=1))
98
+
99
+ start = torch.clamp(decoded_len, 1)
100
+
101
+ # TODO: Vectorize this loop by compute suffix xs and multiplying with log_phi
102
+ # xs = self.xs[:,1:,:].clone()
103
+ # xs_cum = torch.cumsum(xs, dim=1)
104
+ # xs_cum_expanded = xs_cum.unsqueeze(1).repeat(1, T-1, 1, 1)
105
+ # xs_u = (xs_cum_expanded - torch.nn.functional.pad(xs_cum[:,:-1,:], (0,0,1,0), value=0).unsqueeze(2).repeat(1, 1,T-1,1)).permute(0,2,1,3)
106
+ #
107
+ # phis_new = log_phi[:,:-1].clone()
108
+ # phis_new[:, 0] = torch.logaddexp(phis_new[:, 0], self.r[:, 0, 0, :])
109
+ # phis_new = phis_new.unsqueeze(1).repeat(1, T-1, 1, 1)
110
+ # causal_mask = torch.ones((T-1,T-1), dtype=torch.bool, device=self.device).tril().unsqueeze(0).unsqueeze(-1).repeat(B,1,1,1)
111
+ # mask = causal_mask & mask_t.unsqueeze(2).unsqueeze(-1)
112
+ # r_zero = torch.logsumexp(torch.where(mask, xs_u + phis_new, self.logzero), dim=2)
113
+ # self.r[:,1:,0] = r_zero
114
+
115
+ for t in range(start.min(), self.input_length):
116
+ should_decode = decoded_len <= t
117
+ self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0],
118
+ log_phi[:, t - 1]) + self.xs[:, t]
119
+ self.r[:, t, 1] = (
120
+ torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1]) + x_current[:, t, self.blank][:, None]
121
+ )
122
+ if (~should_decode).any():  # mask frames for samples that should not be decoded yet
123
+ self.r[:, t] = torch.where(should_decode.unsqueeze(-1).unsqueeze(-1), self.r[:, t], self.logzero)
124
+
125
+ return log_psi
126
+
127
+ def _update_log_psi_with_eos(self, log_psi, cs, r_sum):
128
+ # Update log_psi for eos positions
129
+ eos_mask = (cs == self.eos)
130
+ log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask]
131
+
132
+ # Exclude blank probabilities if eos is not the blank
133
+ if self.eos != self.blank:
134
+ blank_mask = (cs == self.blank)
135
+ log_psi[blank_mask] = self.logzero
136
+ return log_psi
137
+
138
+ def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev):
139
+ """Compute CTC prefix scores for next labels
140
+
141
+ :param y : prefix label sequence
142
+ :param cs : array of next labels
143
+ :param r_prev: previous CTC state
144
+ :return ctc_scores, ctc_states
145
+ """
146
+ # initialize CTC states
147
+ # output_length = y.shape[1] - 1 # ignore sos
148
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
149
+ # that corresponds to r_t^n(h) and r_t^b(h).
150
+
151
+ # Dynamically resize r and xs to match num_labels if necessary
152
+ num_labels = cs.shape[1]
153
+ number_of_current_samples = cs.shape[0]
154
+ self._resize_tensors(number_of_current_samples, num_labels)
155
+
156
+ # Create a view of the current input frame
157
+ x_current = self.x[samples_to_be_decoded]
158
+ self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1))
159
+
160
+ # Initialize r for the first frame
161
+ self._initialize_r(decoded_len)
162
+
163
+ # prepare forward probabilities for the last label
164
+ r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1]) # log(r_t^n(g) + r_t^b(g))
165
+ last = y[:, -1]
166
+
167
+ # precompute log_phi
168
+ log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev)
169
+
170
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
171
+ # and log prefix probabilities log(psi)
172
+ log_psi = self._compute_log_psi(decoded_len, log_phi, x_current)
173
+
174
+ # get P(...eos|X) that ends with the prefix itself
175
+ log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum)
176
+
177
+ # return the log prefix probability and CTC states, where the label axis
178
+ # of the CTC states is moved to the first axis to slice it easily
179
+ return log_psi, self.r
180
+
181
+
182
+ class CTCRescorerLogitsProcessor(LogitsProcessor):
183
+ def __init__(
184
+ self,
185
+ encoder_logits: torch.FloatTensor,
186
+ encoder_output_lens: torch.Tensor,
187
+ blank_token_id: int,
188
+ pad_token_id: int,
189
+ eos_token_id: int,
190
+ bos_token_id: int,
191
+ tokenizer: PreTrainedTokenizer,
192
+ ctc_margin: int,
193
+ ctc_weight: float,
194
+ num_beams: int,
195
+ debug: bool = False,
196
+ ctc_tokens_to_score: int = 500
197
+ ):
198
+ super().__init__()
199
+ same_logits = torch.tensor(list((tokenizer.upper_cased_tokens.items())))
200
+
201
+ logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1)
202
+ logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]]
203
+
204
+ self.logits = logits
205
+
206
+ self.ctc_prefix_scorer = CTCPrefixScore(
207
+ self.logits,
208
+ blank_token_id,
209
+ eos_token_id,
210
+ )
211
+ self.batch_size = logits.shape[0]
212
+ self.input_length = logits.shape[1]
213
+ self.num_tokens = logits.shape[2]
214
+ self.device = logits.device
215
+ self.ctc_weight = ctc_weight
216
+ self.num_beams = num_beams
217
+ self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state()
218
+ self.eos_token_id = eos_token_id
219
+ self.bos_token_id = bos_token_id
220
+ self.tokenizer = tokenizer
221
+ self.pad_token_id = pad_token_id
222
+ self.blank_token_id = blank_token_id
223
+ self.debug = debug
224
+ self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"]
225
+ self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device)
226
+ self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2),
227
+ device=self.device)
228
+ self.ctc_tokens_to_score = ctc_tokens_to_score
229
+
230
+ def analyze_predictions(self,
231
+ scores, ctc_scores, next_token_scores, input_ids, k=10):
232
+ print("\n" + "#" * 100)
233
+
234
+ batch_size = input_ids.shape[0]
235
+
236
+ best_att_ids = scores.topk(k=k, dim=1)
237
+ ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero
238
+ best_ctc_ids = ctc_scores.topk(k=k, dim=1)
239
+ best_ids = next_token_scores.topk(k=k, dim=1)
240
+
241
+ decoded_prefixes = self.tokenizer.batch_decode(
242
+ input_ids, decode_with_timestamps=True, skip_special_tokens=False
243
+ )
244
+
245
+ def prepare_and_decode(best_ids_tensor):
246
+ new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long)
247
+ new_tensor[:, 0::2] = best_ids_tensor.indices
248
+ new_tensor[:, 1::2] = self.tokenizer.vocab['#']
249
+
250
+ # Flatten to (batch_size * k, 2)
251
+ flat_tensor = new_tensor.view(-1, 2)
252
+ decoded = self.tokenizer.batch_decode(
253
+ flat_tensor, decode_with_timestamps=True, skip_special_tokens=False
254
+ )
255
+ # Reshape back to (batch_size, k)
256
+ decoded = [(decoded[i * k:(i + 1) * k]) for i in range(batch_size)]
257
+ return decoded
258
+
259
+ decoded_att = prepare_and_decode(best_att_ids)
260
+ decoded_ctc = prepare_and_decode(best_ctc_ids)
261
+ decoded_next = prepare_and_decode(best_ids)
262
+
263
+ for idx in range(batch_size):
264
+ print("-" * 80)
265
+ print(f"HYPOTHESIS {idx}")
266
+ print("\nPREFIX:")
267
+ print(decoded_prefixes[idx])
268
+
269
+ def print_with_pandas(tokens, scores, title):
270
+ df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in scores]])
271
+ df.index = [f"{title}", "Score"]
272
+ print(f"\n{title}:")
273
+ print(df.to_string(index=True, header=False))
274
+
275
+ print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS")
276
+ print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS")
277
+ print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS")
278
+
279
+ print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}")
280
+ print()
281
+
282
+ print("#" * 100)
283
+
284
+ def update_state(self, best_ids, beam_idx):
285
+ mask = best_ids < self.first_timestamp_token_id
286
+ self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1),
287
+ self.tmp_ctc_states[beam_idx, best_ids],
288
+ self.ctc_state_prev[beam_idx])
289
+ self.ctc_score_prev = torch.where(mask.unsqueeze(-1),
290
+ self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1),
291
+ self.ctc_score_prev[beam_idx])
292
+
293
+ def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
294
+ input_ids = input_ids_orig.clone()
295
+
296
+ # Remove prefix from CTC scoring
297
+ if (input_ids[:, 0] != self.bos_token_id).any():
298
+ input_ids = torch.stack(
299
+ [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids])
300
+
301
+ # Remove task/lang/timestamp tokens from input_ids
302
+ input_prefix_len = len(self.tokenizer.prefix_tokens)
303
+ if input_prefix_len > 1:
304
+ input_ids = input_ids[:, input_prefix_len - 1:]
305
+
306
+ # Set the first token to be the blank token (sos)
307
+ input_ids[:, 0] = self.blank_token_id
308
+
309
+ # If the last token in input_ids is a timestamp, replace it with the last non-timestamp token, which could even be the first token
310
+ decoded_len = torch.logical_and(input_ids <= self.first_timestamp_token_id,
311
+ input_ids != self.blank_token_id).sum(dim=1)
312
+ mask = torch.logical_and(input_ids[:, -1] >= self.first_timestamp_token_id,
313
+ input_ids[:, -1] != self.blank_token_id)
314
+ last_non_timestamp_token = torch.gather(input_ids, 1,
315
+ torch.logical_or(input_ids < self.first_timestamp_token_id,
316
+ input_ids == self.blank_token_id).sum(dim=1,
317
+ keepdim=True) - 1)
318
+ input_ids[mask, -1] = last_non_timestamp_token[mask, 0]
319
+
320
+ # If there is no eos token in the last position, we need to continue decoding
321
+ to_be_decoded = input_ids[:, -1] != self.eos_token_id
322
+ self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero
323
+
324
+ input_ids_local = input_ids[to_be_decoded]
325
+ ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=self.ctc_tokens_to_score).indices
326
+
327
+ # Always score the EOS token; if it is not among the top-k ids, place it at the last position
328
+ is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1)
329
+ ids_to_score[~is_eos_present, self.ctc_tokens_to_score - 1] = self.eos_token_id
330
+
331
+ decoded_len_local = decoded_len[to_be_decoded]
332
+
333
+ ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded],
334
+ decoded_len_local, to_be_decoded,
335
+ self.ctc_state_prev[to_be_decoded])
336
+
337
+ # As the CTC scorer might run on subset of samples, we need to scatter the results back to the original batch
338
+ self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded]
339
+ .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local))
340
+ self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1)
341
+ .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1)
342
+ .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local)
343
+ .permute(0, 3, 1, 2))
344
+
345
+ # Set the CTC score for the timestamp tokens to the maximum to prefer them over the rest
346
+ self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None]
347
+ ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev
348
+
349
+ next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
350
+
351
+ if self.debug:
352
+ self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig)
353
+
354
+ return next_token_scores
355
+
356
+
357
+ class LogSoftmaxProcessor(LogitsProcessor):
358
+ def __init__(
359
+ self,
360
+ ):
361
+ super().__init__()
362
+
363
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
364
+ scores = torch.nn.functional.log_softmax(scores, dim=-1)
365
+ return scores
366
+
367
+
368
+ class GreedyCTCDecoder(torch.nn.Module):
369
+ def __init__(self, tokenizer, blank=0):
370
+ super().__init__()
371
+ self.blank = blank
372
+ self.tokenizer = tokenizer
373
+
374
+ def forward(self, emission: torch.Tensor) -> List[str]:
375
+ """Given a sequence emission over labels, get the best path
376
+ Args:
377
+ emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
378
+
379
+ Returns:
380
+ List[str]: The resulting transcript
381
+ """
382
+ indices = torch.argmax(emission, dim=-1) # [num_seq,]
383
+ indices = [torch.unique_consecutive(index, dim=-1) for index in indices]
384
+ indices = [index[index != self.blank] for index in indices]
385
+ indices = torch.nn.utils.rnn.pad_sequence(indices, batch_first=True,
386
+ padding_value=self.tokenizer.pad_token_id)
387
+ indices[indices >= len(self.tokenizer)] = self.tokenizer.unk_token_id
388
+ return indices
389
+
390
+
391
+ def ctc_greedy_decode(logits: torch.Tensor, blank, pad_token_id) -> torch.Tensor:
392
+ idxs = torch.argmax(logits, dim=-1)
393
+ for i, prediction in enumerate(idxs):
394
+ deduplicated = [k for k, g in it.groupby(prediction) if k != blank]
395
+ idxs[i, : len(deduplicated)] = torch.tensor(deduplicated)
396
+ idxs[i, len(deduplicated):] = pad_token_id
397
+ return idxs
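
A small usage sketch (not part of the committed file) for `ctc_greedy_decode` above; the random logits and token ids are placeholders:

```python
import torch

# (batch, frames, vocab); by convention in this repo the last id is the CTC blank.
logits = torch.randn(2, 50, 100)
ids = ctc_greedy_decode(logits, blank=99, pad_token_id=0)
print(ids.shape)  # torch.Size([2, 50]) - deduplicated ids, right-padded with pad_token_id
```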
encoder.py ADDED
@@ -0,0 +1,364 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
6
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
7
+
8
+ from .config import DiCoWConfig
9
+
10
+
11
+ class CustomLinear(nn.Linear):
12
+ def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
13
+ super().__init__(*args, **kwargs)
14
+ self.init_eye_val = init_eye_val
15
+
16
+
17
+ class CustomDiagonalLinear(nn.Module):
18
+ def __init__(self, d_model, bias=True, init_eye_val=0.0):
19
+ super().__init__()
20
+ self.init_eye_val = init_eye_val
21
+ self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
22
+ self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
23
+
24
+ def forward(self, input):
25
+ out = input * self.weight
26
+ if self.bias is not None:
27
+ out += self.bias
28
+ return out
29
+
30
+
31
+ class FDDT(nn.Module):
32
+ def __init__(self, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False, use_silence=True,
33
+ use_target=True, use_overlap=True, use_non_target=True, use_interaction=False,
34
+ scb_module: Optional[nn.Module] = None, ):
35
+ super().__init__()
36
+ if use_target:
37
+ self.target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
38
+ CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal else CustomLinear(d_model,
39
+ d_model,
40
+ bias=True,
41
+ init_eye_val=1.0))
42
+ if use_non_target:
43
+ self.non_target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
44
+ CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal else CustomLinear(
45
+ d_model, d_model, bias=True, init_eye_val=non_target_rate))
46
+ if use_overlap:
47
+ self.overlap_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
48
+ CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal else CustomLinear(d_model,
49
+ d_model,
50
+ bias=True,
51
+ init_eye_val=1.0))
52
+ if use_silence:
53
+ self.silence_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
54
+ CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal else CustomLinear(
55
+ d_model, d_model, bias=True, init_eye_val=non_target_rate))
56
+
57
+ if use_interaction:
58
+ self.scb = scb_module if scb_module is not None else (nn.Parameter(torch.zeros(d_model)) if bias_only else (
59
+ CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal else CustomLinear(
60
+ d_model, d_model, bias=True, init_eye_val=1.0)))
61
+
62
+ self.use_silence = use_silence
63
+ self.use_target = use_target
64
+ self.use_overlap = use_overlap
65
+ self.use_non_target = use_non_target
66
+ self.use_interaction = use_interaction
67
+ self.bias_only = bias_only
68
+
69
+ @staticmethod
70
+ def mask_out_non_interaction_signal(hidden_states, mask):
71
+ mask = torch.round(mask).bool()
72
+ masked_hidden_states = hidden_states * mask
73
+ return masked_hidden_states
74
+
75
+ def forward(self, hidden_states, stno_mask):
76
+ stno_mask = stno_mask.to(hidden_states.device)[..., None]
77
+ if self.bias_only:
78
+ if self.use_silence:
79
+ hidden_states += stno_mask[:, 0, ...] * self.silence_linear
80
+ if self.use_target:
81
+ hidden_states += stno_mask[:, 1, ...] * self.target_linear
82
+ if self.use_non_target:
83
+ hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
84
+ if self.use_overlap:
85
+ hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
86
+ if self.use_interaction:
87
+ hidden_states += stno_mask[:, 4, ...] * self.scb
88
+ else:
89
+ orig_hidden_states = hidden_states
90
+ hidden_states = (self.silence_linear(
91
+ orig_hidden_states) if self.use_silence else orig_hidden_states) * stno_mask[:, 0, :] + \
92
+ (self.target_linear(
93
+ orig_hidden_states) if self.use_target else orig_hidden_states) * stno_mask[:, 1, :] + \
94
+ (self.non_target_linear(
95
+ orig_hidden_states) if self.use_non_target else orig_hidden_states) * stno_mask[:, 2,
96
+ :] + \
97
+ (self.overlap_linear(
98
+ orig_hidden_states) if self.use_overlap else orig_hidden_states) * stno_mask[:, 3, :] + \
99
+ (self.scb(
100
+ self.mask_out_non_interaction_signal(orig_hidden_states,
101
+ stno_mask[:, 4, :])) * stno_mask[:, 4,
102
+ :] if self.use_interaction else (
103
+ 0 if stno_mask.size(
104
+ 1) == 4 else orig_hidden_states * stno_mask[:, 4,
105
+ :]))
106
+ return hidden_states
107
+
108
+
109
+ class DiCoWEncoder(WhisperEncoder):
110
+ config_class = DiCoWConfig
111
+
112
+ def __init__(self, config: DiCoWConfig):
113
+ super().__init__(config)
114
+ self.ctc_weight = config.ctc_weight
115
+ if config.additional_layer and self.ctc_weight > 0.0:
116
+ self.additional_layer = WhisperEncoderLayer(config)
117
+ if config.additional_self_attention_layer and self.ctc_weight > 0.0:
118
+ self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
119
+ embed_dim=config.d_model,
120
+ num_heads=config.encoder_attention_heads,
121
+ dropout=config.attention_dropout,
122
+ config=config,
123
+ )
124
+ if config.sub_sample and self.ctc_weight > 0.0:
125
+ self.subsample_conv1 = nn.Conv1d(
126
+ in_channels=config.d_model,
127
+ out_channels=config.d_model,
128
+ kernel_size=3,
129
+ stride=2,
130
+ padding=1,
131
+ bias=False,
132
+ )
133
+ self.subsample_conv2 = nn.Conv1d(
134
+ in_channels=config.d_model,
135
+ out_channels=config.d_model,
136
+ kernel_size=3,
137
+ stride=2,
138
+ padding=1,
139
+ bias=False,
140
+ )
141
+ if self.ctc_weight > 0.0:
142
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
143
+ self.final_dropout = nn.Dropout(config.final_dropout)
144
+ if config.use_fddt:
145
+ num_fddts = self.config.apply_fddt_to_n_layers if self.config.apply_fddt_to_n_layers != -1 else len(
146
+ self.layers)
147
+ self.initial_fddt = FDDT(config.d_model,
148
+ non_target_rate=config.non_target_fddt_value,
149
+ is_diagonal=config.fddt_is_diagonal,
150
+ bias_only=config.fddt_bias_only,
151
+ use_silence=config.fddt_use_silence,
152
+ use_target=config.fddt_use_target,
153
+ use_overlap=config.fddt_use_overlap,
154
+ use_non_target=config.fddt_use_non_target)
155
+ is_mt = config.mt_num_speakers > 1
156
+ num_scbs = (self.config.scb_layers if self.config.scb_layers != -1 else len(
157
+ self.layers)) if is_mt else 0
158
+ self.scbs_identity_layers = config.encoder_layers - num_scbs
159
+ self.fddts = nn.ModuleList([
160
+ FDDT(config.d_model,
161
+ non_target_rate=1.0,
162
+ is_diagonal=config.fddt_is_diagonal,
163
+ bias_only=config.fddt_bias_only,
164
+ use_silence=config.fddt_use_silence,
165
+ use_target=config.fddt_use_target,
166
+ use_overlap=config.fddt_use_overlap,
167
+ use_non_target=config.fddt_use_non_target,
168
+ use_interaction=is_mt,
169
+ )
170
+ for i in range(num_fddts)
171
+ ])
172
+ self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6  # 30 s of 50 Hz timestamps, -1 to reach <|0.00|>, and -6 for the task tokens
173
+ self.post_init()
174
+
175
+ @classmethod
176
+ def _load_pretrained_model(
177
+ cls,
178
+ model,
179
+ state_dict,
180
+ loaded_keys,
181
+ resolved_archive_file,
182
+ pretrained_model_name_or_path,
183
+ **kwargs
184
+ ):
185
+ for key in list(state_dict.keys()):
186
+ if key.startswith("encoder."):
187
+ state_dict[key[8:]] = state_dict.pop(key)
188
+ loaded_keys.remove(key)
189
+ loaded_keys.append(key[8:])
190
+ output = super()._load_pretrained_model(
191
+ model,
192
+ state_dict,
193
+ loaded_keys,
194
+ resolved_archive_file,
195
+ pretrained_model_name_or_path,
196
+ **kwargs
197
+ )
198
+ return output
199
+
200
+ def get_loss(self, logits, labels):
201
+ if labels.max() >= self.config.vocab_size:
202
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
203
+ if self.config.remove_timestamps_from_ctc:
204
+ labels = torch.nn.utils.rnn.pad_sequence([label[label < self.first_task_token] for label in labels],
205
+ padding_value=-100).T
206
+ input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1],
207
+ device=logits.device)
208
+
209
+ # assuming that padded tokens are filled with -100
210
+ # when not being attended to
211
+ labels_mask = labels >= 0
212
+ target_lengths = labels_mask.sum(-1)
213
+ # flattened_targets = labels_enc.masked_select(labels_mask)
214
+
215
+ # ctc_loss doesn't support fp16
216
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
217
+
218
+ with torch.backends.cudnn.flags(enabled=True):
219
+ ctc_loss = nn.functional.ctc_loss(
220
+ log_probs,
221
+ labels,
222
+ input_lengths,
223
+ target_lengths,
224
+ blank=logits.shape[-1] - 1,
225
+ reduction=self.config.ctc_loss_reduction,
226
+ zero_infinity=True,
227
+ )
228
+ return ctc_loss
229
+
230
+ def forward(
231
+ self,
232
+ input_features,
233
+ attention_mask=None,
234
+ head_mask=None,
235
+ output_attentions=None,
236
+ output_hidden_states=None,
237
+ return_dict=None,
238
+ stno_mask=None,
239
+ per_group_sizes=None
240
+ ):
241
+ # For MT-ASR the input has shape (B X S) x F x T
242
+ # we can use torch.view(B, S, F, -1) to obtain
243
+ # new tensor with speaker dim
244
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
245
+ if input_features.shape[-1] != expected_seq_length:
246
+ if input_features.shape[-1] > expected_seq_length:
247
+ return CausalLMOutput(
248
+ logits=None,
249
+ hidden_states=None,
250
+ attentions=None,
251
+ )
252
+ else:
253
+ raise ValueError(
254
+ f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
255
+ )
256
+
257
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
258
+ output_hidden_states = (
259
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
260
+ )
261
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
262
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
263
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
264
+
265
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
266
+ embed_pos = self.embed_positions.weight
267
+ if hasattr(self, "shift_embeds") and self.shift_embeds:
268
+ embed_pos = embed_pos[
269
+ torch.clamp(((stno_mask[:, 1, :] + stno_mask[:, 3, :]).cumsum(dim=-1) - 1), min=0).to(torch.long)]
270
+
271
+ if self.config.use_fddt:
272
+ inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)
273
+
274
+ hidden_states = inputs_embeds + embed_pos
275
+
276
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
277
+
278
+ encoder_states = () if output_hidden_states else None
279
+ all_attentions = () if output_attentions else None
280
+
281
+ # check if head_mask has a correct number of layers specified if desired
282
+ if head_mask is not None:
283
+ assert head_mask.size()[0] == (
284
+ len(self.layers)
285
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
286
+
287
+ for idx, encoder_layer in enumerate(self.layers):
288
+ if output_hidden_states:
289
+ encoder_states = encoder_states + (hidden_states,)
290
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
291
+ to_drop = False
292
+ if self.training:
293
+ dropout_probability = torch.rand([])
294
+ if dropout_probability < self.layerdrop: # skip the layer
295
+ to_drop = True
296
+
297
+ if self.config.use_fddt and idx < len(self.fddts):
298
+ hidden_states = self.fddts[idx](hidden_states, stno_mask)
299
+
300
+ if to_drop:
301
+ layer_outputs = (None, None)
302
+ else:
303
+ if self.gradient_checkpointing and self.training:
304
+ layer_outputs = self._gradient_checkpointing_func(
305
+ encoder_layer.__call__,
306
+ hidden_states,
307
+ None,
308
+ (head_mask[idx] if head_mask is not None else None),
309
+ output_attentions,
310
+ )
311
+ else:
312
+ layer_outputs = encoder_layer(
313
+ hidden_states,
314
+ None,
315
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
316
+ output_attentions=output_attentions,
317
+ )
318
+
319
+ hidden_states = layer_outputs[0]
320
+
321
+ if output_attentions:
322
+ all_attentions = all_attentions + (layer_outputs[1],)
323
+
324
+ hidden_states = self.layer_norm(hidden_states)
325
+ if output_hidden_states:
326
+ encoder_states = encoder_states + (hidden_states,)
327
+
328
+ if not return_dict:
329
+ outputs = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
330
+ else:
331
+ outputs = BaseModelOutput(
332
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
333
+ )
334
+
335
+ if hasattr(self, "additional_layer"):
336
+ inter_output, = self.additional_layer(
337
+ outputs.last_hidden_state,
338
+ attention_mask=None,
339
+ output_attentions=output_attentions,
340
+ layer_head_mask=None,
341
+ )
342
+ elif hasattr(self, "additional_self_attention_layer"):
343
+ inter_output, _, __ = self.additional_self_attention_layer(
344
+ outputs.last_hidden_state,
345
+ attention_mask=None,
346
+ output_attentions=output_attentions,
347
+ layer_head_mask=None,
348
+ )
349
+ else:
350
+ inter_output = outputs.last_hidden_state
351
+
352
+ inter_output = self.final_dropout(inter_output)
353
+ if hasattr(self, "subsample_conv2"):
354
+ inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)
355
+ if self.ctc_weight > 0.0:
356
+ logits = self.lm_head(inter_output)
357
+ else:
358
+ logits = None
359
+
360
+ return CausalLMOutput(
361
+ logits=logits,
362
+ hidden_states=outputs.hidden_states,
363
+ attentions=outputs.attentions,
364
+ )
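
To make the STNO conditioning concrete, here is a standalone sketch of the `FDDT` block applied to dummy encoder states (not part of the committed file; the package import path is hypothetical, since `encoder.py` uses a relative import of `config.py`):

```python
import torch
# Hypothetical package path; in this repo the class lives in encoder.py.
from dicow.encoder import FDDT

fddt = FDDT(d_model=1280, non_target_rate=0.5, is_diagonal=True, bias_only=False)
hidden_states = torch.randn(2, 1500, 1280)   # (batch, encoder frames, d_model)
stno_mask = torch.zeros(2, 4, 1500)          # silence / target / non-target / overlap
stno_mask[:, 1, :] = 1.0                     # pretend every frame is target speech

out = fddt(hidden_states, stno_mask)
print(out.shape)                             # torch.Size([2, 1500, 1280])
```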
generation.py ADDED
@@ -0,0 +1,1770 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+ from typing import Iterator
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from decimal import Decimal, ROUND_HALF_UP
14
+
15
+
16
+ from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor, \
17
+ SuppressTokensAtBeginLogitsProcessor
18
+ from transformers.generation.configuration_utils import GenerationConfig
19
+ from transformers.generation.configuration_utils import GenerationMode
20
+ from transformers.generation.logits_process import (
21
+ LogitsProcessorList,
22
+ SuppressTokensAtBeginLogitsProcessor,
23
+ SuppressTokensLogitsProcessor, )
24
+ from transformers.generation.logits_process import WhisperNoSpeechDetection
25
+ from transformers.generation.stopping_criteria import (
26
+ StoppingCriteriaList,
27
+ )
28
+ from transformers.generation.utils import GenerateBeamOutput, BeamScorer, GenerateBeamDecoderOnlyOutput, \
29
+ stack_model_outputs, GenerateBeamEncoderDecoderOutput, _split_model_inputs, GenerateNonBeamOutput, \
30
+ GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
31
+ from transformers.modeling_outputs import BaseModelOutput
32
+ from transformers.models.whisper.modeling_whisper import (
33
+ WhisperForConditionalGeneration,
34
+ )
35
+ from transformers.models.whisper.generation_whisper import _get_attr_from_logit_processors, _pad_to_max_length
36
+ from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
37
+ from transformers.utils import logging
38
+
39
+ from .utils import WhisperTimeStampLogitsProcessorCustom
40
+ from .decoding import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
41
+
42
+ logging.set_verbosity_debug()
43
+ logger = logging.get_logger("transformers")
44
+
45
+
46
+ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
47
+ def _prepare_encoder_decoder_kwargs_for_generation(
48
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
49
+ ) -> Dict[str, Any]:
50
+ # self.encoder_output_lens = self._get_feat_extract_output_lengths(
51
+ # model_kwargs['attention_mask_enc'].sum(dim=1)
52
+ # ).int()
53
+ generation_config.output_hidden_states = True
54
+
55
+ # pylint: disable=no-member
56
+ model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
57
+ inputs_tensor, model_kwargs, model_input_name, generation_config
58
+ )
59
+ self.encoder_logits = model_kwargs["encoder_outputs"].logits
60
+
61
+ return model_kwargs
62
+
63
+ @staticmethod
64
+ def _expand_inputs_for_generation(
65
+ expand_size: int = 1,
66
+ is_encoder_decoder: bool = False,
67
+ input_ids: Optional[torch.LongTensor] = None,
68
+ **model_kwargs,
69
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
70
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
71
+
72
+ def _expand_dict_for_generation(dict_to_expand):
73
+ for key in dict_to_expand:
74
+ if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key != "loss":
75
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
76
+ return dict_to_expand
77
+
78
+ if input_ids is not None:
79
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
80
+
81
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
82
+
83
+ if is_encoder_decoder:
84
+ if model_kwargs.get("encoder_outputs") is None:
85
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
86
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
87
+ if "hidden_states" in model_kwargs["encoder_outputs"]:
88
+ model_kwargs["encoder_outputs"]["hidden_states"] = tuple(
89
+ hidden_state.repeat_interleave(expand_size, dim=0) for hidden_state in
90
+ model_kwargs["encoder_outputs"]["hidden_states"]
91
+ )
92
+
93
+ return input_ids, model_kwargs
94
+
95
+ def generate(
96
+ self,
97
+ input_features: Optional[torch.Tensor] = None,
98
+ generation_config: Optional[GenerationConfig] = None,
99
+ logits_processor: Optional[LogitsProcessorList] = None,
100
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
101
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
102
+ synced_gpus: bool = False,
103
+ return_timestamps: Optional[bool] = None,
104
+ task: Optional[str] = None,
105
+ language: Optional[str] = None,
106
+ is_multilingual: Optional[bool] = None,
107
+ prompt_ids: Optional[torch.Tensor] = None,
108
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
109
+ condition_on_prev_tokens: Optional[bool] = None,
110
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
111
+ compression_ratio_threshold: Optional[float] = None,
112
+ logprob_threshold: Optional[float] = None,
113
+ no_speech_threshold: Optional[float] = None,
114
+ num_segment_frames: Optional[int] = None,
115
+ attention_mask: Optional[torch.Tensor] = None,
116
+ time_precision: float = 0.02,
117
+ return_token_timestamps: Optional[bool] = None,
118
+ return_segments: bool = False,
119
+ return_dict_in_generate: Optional[bool] = None,
120
+ assistant_model: Optional["PreTrainedModel"] = None,
121
+ **kwargs,
122
+ ):
123
+ if condition_on_prev_tokens:
124
+ raise NotImplementedError("Current version does not support conditioning")
125
+
126
+ gen_c, _ = self._prepare_generation_config(generation_config, **kwargs)
127
+ gen_mode = gen_c.get_generation_mode(assistant_model)
128
+
129
+ if gen_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.BEAM_SEARCH]:
130
+ raise ValueError(
131
+ f"Provided generation mode {gen_mode} is not supported"
132
+ f" for WhisperForConditionalGeneration with joint CTC decoding")
133
+
134
+ if "stno_mask" in kwargs:
135
+ self.stno_mask = kwargs["stno_mask"]
136
+ if "encoder_outputs" in kwargs:
137
+ self.encoder_logits = kwargs["encoder_outputs"].logits
138
+ # pylint: disable=no-member
139
+ # 0. deprecate old inputs
140
+ if "inputs" in kwargs:
141
+ input_features = kwargs.pop("inputs")
142
+ warnings.warn(
143
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
144
+ FutureWarning,
145
+ )
146
+
147
+ # 1. prepare generation config
148
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
149
+
150
+ # 2. set global generate variables
151
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
152
+ num_segment_frames = input_stride * self.config.max_source_positions
153
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
154
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
155
+ )
156
+ is_shortform = total_input_frames <= num_segment_frames
157
+
158
+ if is_shortform:
159
+ # warn user of ignored inputs
160
+ self._maybe_warn_unused_inputs(
161
+ condition_on_prev_tokens=condition_on_prev_tokens,
162
+ temperature=temperature,
163
+ compression_ratio_threshold=compression_ratio_threshold,
164
+ logprob_threshold=logprob_threshold,
165
+ no_speech_threshold=no_speech_threshold,
166
+ total_input_frames=total_input_frames,
167
+ )
168
+
169
+ # 3. Make sure generation config is correctly set
170
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
171
+ self._set_return_outputs(
172
+ return_dict_in_generate=return_dict_in_generate,
173
+ return_token_timestamps=return_token_timestamps,
174
+ is_shortform=is_shortform,
175
+ logprob_threshold=logprob_threshold,
176
+ generation_config=generation_config,
177
+ )
178
+ self._set_return_timestamps(
179
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
180
+ )
181
+ self._set_language_and_task(
182
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
183
+ )
184
+ self._set_num_frames(
185
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
186
+ )
187
+ self._set_thresholds_and_condition(
188
+ generation_config=generation_config,
189
+ logprob_threshold=logprob_threshold,
190
+ compression_ratio_threshold=compression_ratio_threshold,
191
+ no_speech_threshold=no_speech_threshold,
192
+ condition_on_prev_tokens=condition_on_prev_tokens,
193
+ )
194
+ self._set_prompt_condition_type(
195
+ generation_config=generation_config,
196
+ prompt_condition_type=prompt_condition_type,
197
+ )
198
+
199
+ # pass self.config for backward compatibility
200
+ init_tokens = self._retrieve_init_tokens(
201
+ input_features,
202
+ batch_size=batch_size,
203
+ generation_config=generation_config,
204
+ config=self.config,
205
+ num_segment_frames=num_segment_frames,
206
+ kwargs=kwargs,
207
+ )
208
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
209
+ # where the input ids are handled explicitly by the generate method
210
+ self._check_decoder_input_ids(kwargs=kwargs)
211
+
212
+ # 4. Retrieve logits processors
213
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
214
+ begin_index = init_tokens.shape[1]
215
+ logits_processor = self._retrieve_logit_processors(
216
+ generation_config=generation_config,
217
+ logits_processor=logits_processor,
218
+ begin_index=begin_index, # begin index is index of first generated decoder token
219
+ is_shortform=is_shortform,
220
+ num_beams=kwargs.get("num_beams", 1),
221
+ device=device,
222
+ )
223
+
224
+ # 5. If we're in shortform mode, simply generate the whole input at once and return the output
+ if is_shortform:
226
+ if temperature is not None:
227
+ generation_config.temperature = temperature
228
+
229
+ decoder_input_ids = kwargs.pop("decoder_input_ids", None)
230
+ if decoder_input_ids is None:
231
+ decoder_input_ids = init_tokens
232
+
233
+ if prompt_ids is not None:
234
+ decoder_input_ids = torch.cat(
235
+ [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
236
+ )
237
+
238
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
239
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
240
+ raise ValueError(
241
+ f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
242
+ f"is {max_new_tokens}. Thus, the combined length of "
243
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
244
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
245
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
246
+ f"so that their combined length is less than {self.config.max_target_positions}."
247
+ )
248
+
249
+ outputs = super().generate(
250
+ input_features,
251
+ generation_config=generation_config,
252
+ logits_processor=logits_processor,
253
+ stopping_criteria=stopping_criteria,
254
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
255
+ synced_gpus=synced_gpus,
256
+ decoder_input_ids=decoder_input_ids,
257
+ **kwargs,
258
+ )
259
+
260
+ if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"):
261
+ outputs["token_timestamps"] = self._extract_token_timestamps(
262
+ outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames
263
+ )
264
+
265
+ # print("\n".join(self.tokenizer.batch_decode(outputs,skip_special_tokens=True, decode_with_timestamps=True)))
266
+ return outputs
267
+
268
+ # 6. Else we're in longform mode which is more complex.
269
+ # We need to chunk the audio input depending on when the model generates timestamp tokens
270
+
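+ # Rough sketch of the loop below (summary only, no additional logic):
+ # while any stream still has audio left:
+ # cut the next <= 30 s window for every active stream
+ # generate with temperature fallback
+ # split each hypothesis on timestamp tokens into segments
+ # advance each stream's `seek` by the number of consumed frames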
271
+ # 6.1 Set and retrieve global longform generation variables
272
+ self._set_condition_on_prev_tokens(
273
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
274
+ )
275
+
276
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
277
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
278
+ temperature = temperatures[0]
279
+ batch_size = input_features.shape[0]
280
+
281
+ max_frames, seek = self._retrieve_max_frames_and_seek(
282
+ batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
283
+ )
284
+
285
+ # 6.2 Prepare running variables and lists for generation
+ cur_bsz = batch_size
287
+ current_segments = self._prepare_segments(
288
+ prompt_ids=prompt_ids,
289
+ batch_size=batch_size,
290
+ generation_config=generation_config,
291
+ )
292
+
293
+ batch_idx_map = list(range(batch_size))
294
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
295
+
296
+ # 6.2 Transcribe audio until we reach the end of all input audios
297
+ while (seek < max_frames).any():
298
+ # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
299
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
300
+ # to know which original audio is being decoded
301
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
302
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
303
+ input_features=input_features,
304
+ seek=seek,
305
+ max_frames=max_frames,
306
+ cur_bsz=cur_bsz,
307
+ batch_idx_map=batch_idx_map,
308
+ )
309
+ time_offset = seek * time_precision / input_stride
310
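+ # e.g. (illustrative) with seek = 3000 mel frames, time_precision = 0.02 and input_stride = 2,
+ # time_offset = 3000 * 0.02 / 2 = 30.0 s, i.e. the second window starts at 30 s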
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
311
+
312
+ # 6.4 cut out next 30s segment from input features
313
+ segment_input = self._get_input_segment(
314
+ input_features=input_features,
315
+ seek=seek,
316
+ seek_num_frames=seek_num_frames,
317
+ num_segment_frames=num_segment_frames,
318
+ cur_bsz=cur_bsz,
319
+ batch_idx_map=batch_idx_map,
320
+ )
321
+
322
+ # 6.5 prepare decoder input ids
323
+ suppress_tokens = _get_attr_from_logit_processors(
324
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
325
+ )
326
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
327
+ cur_bsz=cur_bsz,
328
+ init_tokens=init_tokens,
329
+ current_segments=current_segments,
330
+ batch_idx_map=batch_idx_map,
331
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
332
+ prompt_ids=prompt_ids,
333
+ generation_config=generation_config,
334
+ config=self.config,
335
+ device=segment_input.device,
336
+ suppress_tokens=suppress_tokens,
337
+ kwargs=kwargs,
338
+ )
339
+
340
+ # 6.6 set max new tokens or max length
341
+ self._set_max_new_tokens_and_length(
342
+ config=self.config,
343
+ decoder_input_ids=decoder_input_ids,
344
+ generation_config=generation_config,
345
+ )
346
+
347
+ # 6.7 Set current `begin_index` for all logit processors
348
+ for proc in logits_processor:
349
+ if hasattr(proc, "set_begin_index"):
350
+ proc.set_begin_index(decoder_input_ids.shape[-1])
351
+
352
+ # 6.8 Run generate with fallback
353
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
354
+ segment_input=segment_input,
355
+ decoder_input_ids=decoder_input_ids,
356
+ cur_bsz=cur_bsz,
357
+ batch_idx_map=batch_idx_map,
358
+ seek=seek,
359
+ num_segment_frames=num_segment_frames,
360
+ max_frames=max_frames,
361
+ temperatures=temperatures,
362
+ generation_config=generation_config,
363
+ logits_processor=logits_processor,
364
+ stopping_criteria=stopping_criteria,
365
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
366
+ synced_gpus=synced_gpus,
367
+ return_token_timestamps=return_token_timestamps,
368
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
369
+ kwargs=kwargs,
370
+ )
371
+
372
+ # 6.9 In every generated sequence, split by timestamp tokens and extract segments
373
+ if self.config.mt_num_speakers == 1:
+ for i, seek_sequence in enumerate(seek_sequences):
375
+ prev_i = batch_idx_map[i]
376
+
377
+ if should_skip[i]:
378
+ seek[prev_i] += seek_num_frames[prev_i]
379
+ continue
380
+
381
+ segments, segment_offset = self._retrieve_segment(
382
+ seek_sequence=seek_sequence,
383
+ seek_outputs=seek_outputs,
384
+ time_offset=time_offset,
385
+ timestamp_begin=timestamp_begin,
386
+ seek_num_frames=seek_num_frames,
387
+ time_precision=time_precision,
388
+ input_stride=input_stride,
389
+ prev_idx=prev_i,
390
+ idx=i,
391
+ return_token_timestamps=return_token_timestamps,
392
+ )
393
+
394
+ current_segments[prev_i] += segments
395
+ seek[prev_i] += segment_offset
396
+ else:
397
+ # We have to keep all speakers of the same recording synchronized, so we process their
+ # hypotheses together and later derive a common seek for the whole group.
+ num_spk = self.config.mt_num_speakers
+ speaker_groups = [seek_sequences[i * num_spk:(i + 1) * num_spk] for i in range(len(seek_sequences) // num_spk)]
+ for j, seek_seqs in enumerate(speaker_groups):
+ indexes = [j * num_spk + i for i in range(num_spk)]
+ prev_ids = [batch_idx_map[i] for i in indexes]
401
+
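+ # Illustrative grouping: with mt_num_speakers = 2 and four parallel streams, group j = 0
+ # covers indexes [0, 1] and group j = 1 covers indexes [2, 3], so the per-speaker
+ # hypotheses of the same recording are post-processed together.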
402
+ if all([should_skip[i] for i in indexes]):
403
+ for i, prev_i in zip(indexes, prev_ids):
404
+ seek[prev_i] += seek_num_frames[prev_i]
405
+ continue
406
+
407
+ segments, segment_offset = self._retrieve_segment_mt(
408
+ seek_sequences=seek_seqs,
409
+ seek_outputs=seek_outputs,
410
+ time_offset=time_offset,
411
+ timestamp_begin=timestamp_begin,
412
+ seek_num_frames=seek_num_frames,
413
+ time_precision=time_precision,
414
+ input_stride=input_stride,
415
+ prev_ids=prev_ids,
416
+ ids=indexes,
417
+ return_token_timestamps=return_token_timestamps,
418
+ )
419
+
420
+ for prev_i, i in zip(prev_ids, range(self.config.mt_num_speakers)):
421
+ current_segments[prev_i] += segments[i]
422
+ seek[prev_i] += segment_offset[i]
423
+
424
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
425
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
426
+ final_segments = (
427
+ [x[1:] for x in current_segments]
428
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
429
+ else current_segments
430
+ )
431
+ sequences = _pad_to_max_length(
432
+ final_segments, generation_config.pad_token_id, device=self.device, padding="right"
433
+ )
434
+
435
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
436
+ output = {"sequences": sequences, "segments": final_segments}
437
+
438
+ self.encoder_logits = None
439
+
440
+ if isinstance(output, dict):
441
+ output = self._fix_timestamps_from_segmentation(output)
442
+
443
+ return output
444
+
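+ # Illustrative end-to-end call of the override above (a sketch, not the repository's
+ # documented API; variable names and the exact diarization-mask pipeline are assumptions):
+ #
+ # feats = processor(audio, sampling_rate=16_000, return_tensors="pt").input_features
+ # out = model.generate(input_features=feats, stno_mask=stno_mask,
+ # return_timestamps=True, return_segments=True)
+ # # for longform inputs `out` is a dict with "sequences" and "segments"
+ # text = processor.batch_decode(out["sequences"], skip_special_tokens=True)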
445
+ @staticmethod
446
+ def _find_common_seek(sequences, seeks):
447
+ """
448
+ Finds the minimum seek that does not overlap with segments of the other sequences,
+ falling back to candidate seeks derived from segment boundaries if needed. Assumes:
+ - 'seeks' is a list of integer seek times, one per sequence,
+ - each seek time is in timestamp * 100 format (e.g., 125.5 s -> 12550).
+ """
453
+
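+ # Illustrative behaviour (comment only): with segments
+ # sequences = [[{"start": 0.0, "end": 6.0}], [{"start": 4.0, "end": 9.0}]] and
+ # seeks = [900, 500], the minimum seek (500, i.e. 5.0 s) lies inside the first
+ # speaker's 0.0-6.0 s segment, so the helper falls back to a seek derived from the
+ # segment boundaries instead of cutting through an unfinished segment.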
454
+ def is_valid_seek(seek_time, exclude_seq_idx):
455
+ for idx, seq in enumerate(sequences):
456
+ if idx == exclude_seq_idx:
457
+ continue
458
+ for segment in seq:
459
+ start = getattr(segment, 'start', segment['start'])
460
+ end = getattr(segment, 'end', segment['end'])
461
+ if seek_time < start:
462
+ break # Segments are sorted by end
463
+ if start < seek_time < end:
464
+ return False
465
+ return True
466
+
467
+ # Step 1: Find minimum seek
468
+ # if all seek values are the same, return it immediately
469
+ seeks = [s if isinstance(s, int) else s.item() for s in seeks]
470
+ if len(set(seeks)) == 1:
471
+ return seeks[0]
472
+
473
+ min_seek_val = min(seeks)
474
+ min_seek_idx = seeks.index(min_seek_val)
475
+ min_seek_real = min_seek_val / 100
476
+
477
+ if is_valid_seek(min_seek_real, min_seek_idx):
478
+ return min_seek_val
479
+
480
+ # Step 2: Try fallback seeks derived from the segment boundaries of all sequences
+ fallback_seeks = set()
482
+ for idx, seq in enumerate(sequences):
483
+ for segment in seq:
484
+ start = getattr(segment, 'start', segment['start'])
485
+ if isinstance(start, torch.Tensor):
486
+ start = start.item()
487
+ candidate = round(start, 2)
488
+ fallback_seeks.add((candidate, idx, True))
489
+ end = getattr(segment, 'end', segment['end'])
490
+ if isinstance(end, torch.Tensor):
491
+ end = end.item()
492
+ if end < min_seek_real:
493
+ candidate = round(end, 2)
494
+ fallback_seeks.add((candidate, idx, True))
495
+
496
+ valid_fallbacks = [
497
+ (int(s * 100), idx, is_start) for s, idx, is_start in fallback_seeks
498
+ if is_valid_seek(s, min_seek_idx)
499
+ ]
500
+
501
+ if valid_fallbacks:
502
+ # return the seek value (timestamp * 100) of the latest valid fallback candidate
+ return max(valid_fallbacks)[0]
+
504
+ # Step 3: Nothing valid
505
+ return 0
506
+
507
+ @staticmethod
508
+ def remove_segments_after_seek(sequences, seek, eps=100):
509
+ """
510
+ Keep only the segments that finish before the given seek timestamp (plus a small tolerance).
+
+ Args:
+ sequences: List of lists, each containing segments (dict or object with 'start' and 'end').
+ seek: Integer seek timestamp (e.g., timestamp * 100).
+ eps: Tolerance in the same timestamp * 100 units, added to `seek` when filtering.
+
+ Returns:
+ A new list of sequences containing only the kept segments; the inputs are not modified.
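+
+ Example (illustrative, with the default eps of 100, i.e. a 1 s tolerance):
+ sequences = [[{"start": 0.0, "end": 4.0}, {"start": 10.0, "end": 31.0}]], seek = 500
+ -> [[{"start": 0.0, "end": 4.0}]] (the segment ending at 31.0 s is dropped)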
+ """
519
+ return [[seg for seg in seq if (getattr(seg, 'end', seg['end']) * 100 <= seek + eps)] for seq in sequences]
+
521
+
522
+
523
+ @staticmethod
524
+ def _retrieve_segment_wo_seek(
525
+ seek_sequence,
526
+ seek_outputs,
527
+ time_offset,
528
+ timestamp_begin,
529
+ seek_num_frames,
530
+ time_precision,
531
+ input_stride,
532
+ prev_idx,
533
+ idx,
534
+ return_token_timestamps,
535
+ ):
536
+ # find Whisper's "end of segment" predictions
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
538
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
539
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
540
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
541
+ timestamp_segment_indices.add_(1)
542
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
543
+
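+ # Illustrative slicing (comment only): for a hypothesis like
+ # <|0.00|> hello <|1.20|> <|1.20|> world <|2.00|>, the consecutive pair <|1.20|><|1.20|>
+ # marks an "end of segment", so the first segment covers 0.00-1.20 s and the remainder
+ # (1.20-2.00 s) is handled by the following slice.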
544
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
+ # "end of segment" prediction and slice the decoding into segments accordingly
546
+ if len(timestamp_segment_indices) > 0:
547
+ # if the output contains two consecutive timestamp tokens
548
+ slices = timestamp_segment_indices.tolist()
549
+ segments = []
550
+ if single_timestamp_ending:
551
+ slices.append(len(seek_sequence))
552
+
553
+ last_slice = 0
554
+ # Add each segment to list of all segments
555
+ for current_slice in slices:
556
+ sliced_tokens = seek_sequence[last_slice:current_slice]
557
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
558
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
559
+ segments.append(
560
+ {
561
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
562
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
563
+ "tokens": sliced_tokens,
564
+ "result": seek_outputs[idx],
565
+ }
566
+ )
567
+ if return_token_timestamps:
568
+ segments[-1]["token_timestamps"] = (
569
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
570
+ )
571
+ last_slice = current_slice
572
+
573
+ if not single_timestamp_ending:
574
+ # generate all predictions after the last predicted "end of segment" and seek by 30s
575
+ sliced_tokens = seek_sequence[last_slice:]
576
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
577
+ end_timestamp_pos = seek_num_frames[prev_idx] // 2
578
+ segments.append(
579
+ {
580
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
581
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
582
+ "tokens": sliced_tokens,
583
+ "result": seek_outputs[idx],
584
+ }
585
+ )
586
+ segment_offset = seek_num_frames[prev_idx]
587
+ else:
588
+ # If whisper does not predict any "end of segment" token, then
589
+ # the whole decoding is considered a segment and we add it to the list of segments
590
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
591
+ start_timestamp_pos = 0.0
592
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
593
+
594
+ if timestamps.numel() > 1:
595
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
596
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
597
+ elif timestamps.numel() == 1:
598
+ # no consecutive timestamps but it has a timestamp; use the last one.
599
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
600
+ segments = [
601
+ {
602
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
603
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
604
+ "tokens": seek_sequence,
605
+ "result": seek_outputs[idx],
606
+ }
607
+ ]
608
+
609
+ segment_offset = seek_num_frames[prev_idx]
610
+
611
+ return segments, segment_offset
612
+
613
+ def _retrieve_segment_mt(
614
+ self,
615
+ seek_sequences,
616
+ seek_outputs,
617
+ time_offset,
618
+ timestamp_begin,
619
+ seek_num_frames,
620
+ time_precision,
621
+ input_stride,
622
+ prev_ids,
623
+ ids,
624
+ return_token_timestamps,
625
+ ):
626
+ sequences, seeks = [], []
627
+ for sequence, prev_id, idx in zip(seek_sequences, prev_ids, ids):
628
+ seq, seek = self._retrieve_segment(
629
+ seek_sequence=sequence,
630
+ seek_outputs=seek_outputs,
631
+ time_offset=time_offset,
632
+ timestamp_begin=timestamp_begin,
633
+ seek_num_frames=seek_num_frames,
634
+ time_precision=time_precision,
635
+ input_stride=input_stride,
636
+ prev_idx=prev_id,
637
+ idx=idx,
638
+ return_token_timestamps=return_token_timestamps,
639
+ )
640
+ sequences.append(seq)
641
+ seeks.append(seek + int(time_offset[prev_id] * 100))
+ # best_seek = self._find_common_seek(sequences, seeks)
643
+ best_seek = seeks[0]
644
+ # print(f"Best seek {best_seek}")
645
+ if best_seek - (min(time_offset[prev_ids]) * 100) < 100:
+ # we cannot roll back, so we have to decode the segments as they are
+ sequences, seeks = [], []
648
+ for sequence, prev_id, idx in zip(seek_sequences, prev_ids, ids):
649
+ seq, seek = self._retrieve_segment_wo_seek(
650
+ seek_sequence=sequence,
651
+ seek_outputs=seek_outputs,
652
+ time_offset=time_offset,
653
+ timestamp_begin=timestamp_begin,
654
+ seek_num_frames=seek_num_frames,
655
+ time_precision=time_precision,
656
+ input_stride=input_stride,
657
+ prev_idx=prev_id,
658
+ idx=idx,
659
+ return_token_timestamps=return_token_timestamps,
660
+ )
661
+ sequences.append(seq)
662
+ seeks.append(seek)
663
+ return sequences, seeks
664
+
665
+ seqs_new = self.remove_segments_after_seek(sequences, best_seek)
666
+ seeks = [best_seek - int(min(time_offset[prev_ids]) * 100) for _ in seeks]
667
+ return seqs_new, seeks
668
+
669
+ def _beam_search(
670
+ self,
671
+ input_ids: torch.LongTensor,
672
+ beam_scorer: BeamScorer,
673
+ logits_processor: LogitsProcessorList,
674
+ stopping_criteria: StoppingCriteriaList,
675
+ generation_config: GenerationConfig,
676
+ synced_gpus: bool,
677
+ logits_warper: Optional[LogitsProcessorList] = None,
678
+ **model_kwargs,
679
+ ) -> Union[GenerateBeamOutput, torch.LongTensor]:
680
+ r"""
681
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
682
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
683
+
684
+ Parameters:
685
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
686
+ The sequence used as a prompt for the generation.
687
+ beam_scorer (`BeamScorer`):
688
+ A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
690
+ logits_processor (`LogitsProcessorList`):
691
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
692
+ used to modify the prediction scores of the language modeling head applied at each generation step.
693
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
695
+ used to tell if the generation loop should stop.
696
+ generation_config ([`~generation.GenerationConfig`]):
697
+ The generation configuration to be used as parametrization of the decoding method.
698
+ synced_gpus (`bool`):
699
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
700
+ logits_warper (`LogitsProcessorList`, *optional*):
701
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
702
+ to warp the prediction score distribution of the language modeling head applied before multinomial
703
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
704
+ `generation_config`)
705
+ model_kwargs:
706
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
707
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
708
+
709
+ Return:
710
+ [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
+ `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
712
+ [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
713
+ `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
714
+ `model.config.is_encoder_decoder=True`.
715
+ """
716
+ # init values
717
+ pad_token_id = generation_config.pad_token_id
718
+ eos_token_id = generation_config.eos_token_id
719
+ output_attentions = generation_config.output_attentions
720
+ output_hidden_states = generation_config.output_hidden_states
721
+ output_scores = generation_config.output_scores
722
+ output_logits = generation_config.output_logits
723
+ return_dict_in_generate = generation_config.return_dict_in_generate
724
+ sequential = generation_config.low_memory
725
+ do_sample = generation_config.do_sample
726
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
727
+ raise ValueError(
728
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
729
+ f"{logits_warper})."
730
+ )
731
+
732
+ batch_size = len(beam_scorer._beam_hyps)
733
+ num_beams = beam_scorer.num_beams
734
+
735
+ batch_beam_size, cur_len = input_ids.shape
736
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
737
+
738
+ if num_beams * batch_size != batch_beam_size:
739
+ raise ValueError(
740
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
741
+ )
742
+
743
+ # init attention / hidden states / scores tuples
744
+ scores = () if (return_dict_in_generate and output_scores) else None
745
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
746
+ beam_indices = (
747
+ tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
748
+ )
749
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
750
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
751
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
752
+
753
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
754
+ if return_dict_in_generate and self.config.is_encoder_decoder:
755
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
756
+ encoder_hidden_states = (
757
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
758
+ )
759
+
760
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
761
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
762
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
763
+ beam_scores[:, 1:] = -1e9
764
+ beam_scores = beam_scores.view((batch_size * num_beams,))
765
+
766
+ this_peer_finished = False
767
+
768
+ decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
769
+
770
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
771
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
772
+
773
+ # if sequential is True, split the input to batches of batch_size and run sequentially
774
+ if sequential:
775
+ if any(
776
+ model_name in self.__class__.__name__.lower()
777
+ for model_name in [
778
+ "fsmt",
779
+ "reformer",
780
+ "bloom",
781
+ "ctrl",
782
+ "gpt_bigcode",
783
+ "transo_xl",
784
+ "xlnet",
785
+ "cpm",
786
+ "jamba",
787
+ ]
788
+ ):
789
+ raise RuntimeError(
790
+ f"Currently generation for {self.__class__.__name__} is not supported "
791
+ f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
792
+ )
793
+
794
+ inputs_per_sub_batches = _split_model_inputs(
795
+ model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
796
+ )
797
+ outputs_per_sub_batch = [
798
+ self(
799
+ **inputs_per_sub_batch,
800
+ return_dict=True,
801
+ output_attentions=output_attentions,
802
+ output_hidden_states=output_hidden_states,
803
+ )
804
+ for inputs_per_sub_batch in inputs_per_sub_batches
805
+ ]
806
+
807
+ outputs = stack_model_outputs(outputs_per_sub_batch)
808
+
809
+ else: # Unchanged original behavior
810
+ outputs = self(
811
+ **model_inputs,
812
+ return_dict=True,
813
+ output_attentions=output_attentions,
814
+ output_hidden_states=output_hidden_states,
815
+ )
816
+
817
+ if synced_gpus and this_peer_finished:
818
+ cur_len = cur_len + 1
819
+ continue # don't waste resources running the code we don't need
820
+
821
+ next_token_logits = outputs.logits[:, -1, :]
822
+ next_token_scores = nn.functional.log_softmax(
823
+ next_token_logits, dim=-1
824
+ ) # (batch_size * num_beams, vocab_size)
825
+
826
+ next_token_scores_processed = logits_processor(input_ids, next_token_scores)
827
+ if do_sample:
828
+ next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
829
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
830
+ next_token_scores_processed
831
+ )
832
+
833
+ # Store scores, attentions and hidden_states when required
834
+ if return_dict_in_generate:
835
+ if output_scores:
836
+ scores += (next_token_scores_processed,)
837
+ if output_logits:
838
+ raw_logits += (next_token_logits,)
839
+ if output_attentions:
840
+ decoder_attentions += (
841
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
842
+ )
843
+ if self.config.is_encoder_decoder:
844
+ cross_attentions += (outputs.cross_attentions,)
845
+ if output_hidden_states:
846
+ decoder_hidden_states += (
847
+ (outputs.decoder_hidden_states,)
848
+ if self.config.is_encoder_decoder
849
+ else (outputs.hidden_states,)
850
+ )
851
+
852
+ # reshape for beam search
853
+ vocab_size = next_token_scores.shape[-1]
854
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
855
+
856
+ # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
857
+ # non eos token per beam.
858
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
859
+ n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
860
+ if do_sample:
861
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
862
+ next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
863
+ next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
864
+ next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
865
+ next_tokens = torch.gather(next_tokens, -1, _indices)
866
+ else:
867
+ next_token_scores, next_tokens = torch.topk(
868
+ next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
869
+ )
870
+
871
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
872
+ next_tokens = next_tokens % vocab_size
873
+
874
+ # stateless
875
+ beam_outputs = beam_scorer.process(
876
+ input_ids,
877
+ next_token_scores,
878
+ next_tokens,
879
+ next_indices,
880
+ pad_token_id=pad_token_id,
881
+ eos_token_id=eos_token_id,
882
+ beam_indices=beam_indices,
883
+ decoder_prompt_len=decoder_prompt_len,
884
+ )
885
+
886
+ beam_scores = beam_outputs["next_beam_scores"]
887
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
888
+ beam_idx = beam_outputs["next_beam_indices"]
889
+
890
+ # Based on the beam idx and next tokens reshuffle the ctc prev states and scores
891
+ if hasattr(self, "ctc_rescorer"):
892
+ self.ctc_rescorer.update_state(beam_next_tokens, beam_idx)
893
+ input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
894
+
895
+ model_kwargs = self._update_model_kwargs_for_generation(
896
+ outputs,
897
+ model_kwargs,
898
+ is_encoder_decoder=self.config.is_encoder_decoder,
899
+ )
900
+ if model_kwargs.get("past_key_values", None) is not None:
901
+ model_kwargs["past_key_values"] = self._temporary_reorder_cache(
902
+ model_kwargs["past_key_values"], beam_idx
903
+ )
904
+
905
+ if return_dict_in_generate and output_scores:
906
+ beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
907
+
908
+ # increase cur_len
909
+ cur_len = cur_len + 1
910
+
911
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
912
+ this_peer_finished = True
913
+
914
+ sequence_outputs = beam_scorer.finalize(
915
+ input_ids,
916
+ beam_scores,
917
+ next_tokens,
918
+ next_indices,
919
+ pad_token_id=pad_token_id,
920
+ eos_token_id=eos_token_id,
921
+ max_length=stopping_criteria.max_length,
922
+ beam_indices=beam_indices,
923
+ decoder_prompt_len=decoder_prompt_len,
924
+ )
925
+
926
+ if return_dict_in_generate:
927
+ if not output_scores:
928
+ sequence_outputs["sequence_scores"] = None
929
+
930
+ if self.config.is_encoder_decoder:
931
+ return GenerateBeamEncoderDecoderOutput(
932
+ sequences=sequence_outputs["sequences"],
933
+ sequences_scores=sequence_outputs["sequence_scores"],
934
+ scores=scores,
935
+ logits=raw_logits,
936
+ beam_indices=sequence_outputs["beam_indices"],
937
+ encoder_attentions=encoder_attentions,
938
+ encoder_hidden_states=encoder_hidden_states,
939
+ decoder_attentions=decoder_attentions,
940
+ cross_attentions=cross_attentions,
941
+ decoder_hidden_states=decoder_hidden_states,
942
+ past_key_values=model_kwargs.get("past_key_values"),
943
+ )
944
+ else:
945
+ return GenerateBeamDecoderOnlyOutput(
946
+ sequences=sequence_outputs["sequences"],
947
+ sequences_scores=sequence_outputs["sequence_scores"],
948
+ scores=scores,
949
+ logits=raw_logits,
950
+ beam_indices=sequence_outputs["beam_indices"],
951
+ attentions=decoder_attentions,
952
+ hidden_states=decoder_hidden_states,
953
+ past_key_values=model_kwargs.get("past_key_values"),
954
+ )
955
+ else:
956
+ return sequence_outputs["sequences"]
957
+
958
+ def _sample(
959
+ self,
960
+ input_ids: torch.LongTensor,
961
+ logits_processor: LogitsProcessorList,
962
+ stopping_criteria: StoppingCriteriaList,
963
+ generation_config: GenerationConfig,
964
+ synced_gpus: bool,
965
+ streamer: Optional["BaseStreamer"],
966
+ logits_warper: Optional[LogitsProcessorList] = None,
967
+ **model_kwargs,
968
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
969
+ r"""
970
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
971
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
972
+
973
+ Parameters:
974
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
975
+ The sequence used as a prompt for the generation.
976
+ logits_processor (`LogitsProcessorList`):
977
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
978
+ used to modify the prediction scores of the language modeling head applied at each generation step.
979
+ stopping_criteria (`StoppingCriteriaList`):
980
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
981
+ used to tell if the generation loop should stop.
982
+ generation_config ([`~generation.GenerationConfig`]):
983
+ The generation configuration to be used as parametrization of the decoding method.
984
+ synced_gpus (`bool`):
985
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
986
+ streamer (`BaseStreamer`, *optional*):
987
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
988
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
989
+ logits_warper (`LogitsProcessorList`, *optional*):
990
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
991
+ to warp the prediction score distribution of the language modeling head applied before multinomial
992
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
993
+ `generation_config`)
994
+ model_kwargs:
995
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
996
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
997
+
998
+ Return:
999
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
1000
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
1001
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
1002
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
1003
+ `model.config.is_encoder_decoder=True`.
1004
+ """
1005
+ # init values
1006
+ pad_token_id = generation_config.pad_token_id
1007
+ output_attentions = generation_config.output_attentions
1008
+ output_hidden_states = generation_config.output_hidden_states
1009
+ output_scores = generation_config.output_scores
1010
+ output_logits = generation_config.output_logits
1011
+ return_dict_in_generate = generation_config.return_dict_in_generate
1012
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
1013
+ do_sample = generation_config.do_sample
1014
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
1015
+ raise ValueError(
1016
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
1017
+ f"{logits_warper})."
1018
+ )
1019
+
1020
+ # init attention / hidden states / scores tuples
1021
+ scores = () if (return_dict_in_generate and output_scores) else None
1022
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
1023
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
1024
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
1025
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
1026
+
1027
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
1028
+ if return_dict_in_generate and self.config.is_encoder_decoder:
1029
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
1030
+ encoder_hidden_states = (
1031
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
1032
+ )
1033
+
1034
+ # keep track of which sequences are already finished
1035
+ batch_size = input_ids.shape[0]
1036
+ this_peer_finished = False
1037
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1038
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1039
+
1040
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
1041
+ # prepare model inputs
1042
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1043
+
1044
+ # forward pass to get next token
1045
+ outputs = self(
1046
+ **model_inputs,
1047
+ return_dict=True,
1048
+ output_attentions=output_attentions,
1049
+ output_hidden_states=output_hidden_states,
1050
+ )
1051
+
1052
+ if synced_gpus and this_peer_finished:
1053
+ continue # don't waste resources running the code we don't need
1054
+
1055
+ next_token_logits = outputs.logits[:, -1, :]
1056
+
1057
+ # pre-process distribution
1058
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1059
+ if do_sample:
1060
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1061
+
1062
+ # Store scores, attentions and hidden_states when required
1063
+ if return_dict_in_generate:
1064
+ if output_scores:
1065
+ scores += (next_token_scores,)
1066
+ if output_logits:
1067
+ raw_logits += (next_token_logits,)
1068
+ if output_attentions:
1069
+ decoder_attentions += (
1070
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
1071
+ )
1072
+ if self.config.is_encoder_decoder:
1073
+ cross_attentions += (outputs.cross_attentions,)
1074
+
1075
+ if output_hidden_states:
1076
+ decoder_hidden_states += (
1077
+ (outputs.decoder_hidden_states,)
1078
+ if self.config.is_encoder_decoder
1079
+ else (outputs.hidden_states,)
1080
+ )
1081
+
1082
+ # token selection
1083
+ if do_sample:
1084
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1085
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1086
+ else:
1087
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
1088
+
1089
+ # finished sentences should have their next token be a padding token
1090
+ if has_eos_stopping_criteria:
1091
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1092
+
1093
+ # Based on the next tokens select the ctc prev states and scores
1094
+ if hasattr(self, "ctc_rescorer"):
1095
+ self.ctc_rescorer.update_state(next_tokens, torch.arange(next_tokens.shape[0]))
1096
+
1097
+ # update generated ids, model inputs, and length for next step
1098
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1099
+ if streamer is not None:
1100
+ streamer.put(next_tokens.cpu())
1101
+ model_kwargs = self._update_model_kwargs_for_generation(
1102
+ outputs,
1103
+ model_kwargs,
1104
+ is_encoder_decoder=self.config.is_encoder_decoder,
1105
+ )
1106
+
1107
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
1108
+ this_peer_finished = unfinished_sequences.max() == 0
1109
+
1110
+ if streamer is not None:
1111
+ streamer.end()
1112
+
1113
+ if return_dict_in_generate:
1114
+ if self.config.is_encoder_decoder:
1115
+ return GenerateEncoderDecoderOutput(
1116
+ sequences=input_ids,
1117
+ scores=scores,
1118
+ logits=raw_logits,
1119
+ encoder_attentions=encoder_attentions,
1120
+ encoder_hidden_states=encoder_hidden_states,
1121
+ decoder_attentions=decoder_attentions,
1122
+ cross_attentions=cross_attentions,
1123
+ decoder_hidden_states=decoder_hidden_states,
1124
+ past_key_values=model_kwargs.get("past_key_values"),
1125
+ )
1126
+ else:
1127
+ return GenerateDecoderOnlyOutput(
1128
+ sequences=input_ids,
1129
+ scores=scores,
1130
+ logits=raw_logits,
1131
+ attentions=decoder_attentions,
1132
+ hidden_states=decoder_hidden_states,
1133
+ past_key_values=model_kwargs.get("past_key_values"),
1134
+ )
1135
+ else:
1136
+ return input_ids
1137
+
1138
+ def prepare_kwargs_for_generate(self,
1139
+ segment_input,
1140
+ cur_bsz,
1141
+ batch_idx_map,
1142
+ seek,
1143
+ num_segment_frames,
1144
+ max_frames,
1145
+ kwargs):
1146
+ kwargs["attention_mask_enc"] = torch.ones(cur_bsz, segment_input.size(-1), device=segment_input.device)
1147
+ seek_vad = seek // 2
1148
+ num_frames_vad = num_segment_frames // 2
1149
+ max_frames_vad = max_frames // 2
1150
+ seek_num_frames = (max_frames_vad - seek_vad).clamp(max=num_frames_vad)
1151
+
1152
+ stno_masks = []
1153
+ for i in range(cur_bsz):
1154
+ prev_i = batch_idx_map[i]
1155
+ segment_input_slice = kwargs["stno_mask"][prev_i: prev_i + 1, :,
1156
+ seek_vad[prev_i]: seek_vad[prev_i] + seek_num_frames[prev_i]]
1157
+
1158
+ if segment_input_slice.shape[-1] < num_frames_vad:
1159
+ orig_len = segment_input_slice.shape[-1]
1160
+ # pad the mask slice to the full segment length (num_frames_vad) if necessary
+ segment_input_slice = torch.nn.functional.pad(
1162
+ segment_input_slice, pad=(0, num_frames_vad - orig_len)
1163
+ )
1164
+ # set the corresponding padded frames to 1 in the mask, i.e. mark them as silence
+ segment_input_slice[0, 0, orig_len:] = 1.0
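+ # (Assumption noted for clarity: channel 0 of the `stno_mask` is taken to be the silence
+ # class of the silence/target/non-target/overlap decomposition, so padded frames count as silence.)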
1166
+
1167
+ stno_masks.append(segment_input_slice)
1168
+ kwargs["stno_mask"] = torch.cat(stno_masks, dim=0)
1169
+ self.stno_mask_seek = kwargs["stno_mask"]
1170
+
1171
+ if "per_group_sizes" in kwargs:
1172
+ group_sizes = kwargs["per_group_sizes"].clone()
1173
+ group_sizes[:] = 0
1174
+ cumulative_group_sizes = (
+ kwargs["per_group_sizes"].max().repeat(kwargs["per_group_sizes"].shape[0])).cumsum(dim=0)
+ for i in batch_idx_map:
+ group_idx = (cumulative_group_sizes > i).nonzero().min()
+ group_sizes[group_idx] += 1
1179
+ kwargs["per_group_sizes"] = group_sizes
1180
+
1181
+ if self.vad_seek_callback is not None:
1182
+ self.vad_seek_callback(kwargs["stno_mask"])
1183
+ return kwargs
1184
+
1185
+ def generate_with_fallback(
1186
+ self,
1187
+ segment_input,
1188
+ decoder_input_ids,
1189
+ cur_bsz,
1190
+ batch_idx_map,
1191
+ seek,
1192
+ num_segment_frames,
1193
+ max_frames,
1194
+ temperatures,
1195
+ generation_config,
1196
+ logits_processor,
1197
+ stopping_criteria,
1198
+ prefix_allowed_tokens_fn,
1199
+ synced_gpus,
1200
+ return_token_timestamps,
1201
+ do_condition_on_prev_tokens,
1202
+ kwargs,
1203
+ ):
1204
+ kwargs = copy.copy(kwargs)
1205
+ kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek, num_segment_frames,
1206
+ max_frames, kwargs)
1207
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = super().generate_with_fallback(
1208
+ segment_input,
1209
+ decoder_input_ids,
1210
+ cur_bsz,
1211
+ batch_idx_map,
1212
+ seek,
1213
+ num_segment_frames,
1214
+ max_frames,
1215
+ temperatures,
1216
+ generation_config,
1217
+ logits_processor,
1218
+ stopping_criteria,
1219
+ prefix_allowed_tokens_fn,
1220
+ synced_gpus,
1221
+ return_token_timestamps,
1222
+ do_condition_on_prev_tokens,
1223
+ kwargs,
1224
+ )
1225
+ self.stno_mask_seek = None
+
1227
+ # for i, seq in enumerate(seek_outputs):
1228
+ # print(f"Sequence {i}: {self.tokenizer.decode(seq, decode_with_timestamps=True)}")
1229
+ # print("-"*50)
1230
+
1231
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
1232
+
1233
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1234
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1235
+ """short function to replace num with a itr in lst"""
1236
+ found = any(i in lst for i in itr)
1237
+ if found:
1238
+ lst = [num if i in itr else i for i in lst]
1239
+ else:
1240
+ lst.append(num)
1241
+ return lst
1242
+
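+ # e.g. (illustrative token ids) replace_or_add([50258, 50359], 50360, {50359, 50360})
+ # -> [50258, 50360]; replace_or_add([50258], 50360, {50359, 50360}) -> [50258, 50360] (appended)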
1243
+ def language_to_id(language: str) -> int:
1244
+ language = language.lower()
1245
+ if language in generation_config.lang_to_id.keys():
1246
+ language_token = language
1247
+ elif language in TO_LANGUAGE_CODE.keys():
1248
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1249
+ elif language in TO_LANGUAGE_CODE.values():
1250
+ language_token = f"<|{language}|>"
1251
+ else:
1252
+ is_language_code = len(language) == 2
1253
+ raise ValueError(
1254
+ f"Unsupported language: {language}. Language should be one of:"
1255
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1256
+ )
1257
+ if language_token not in generation_config.lang_to_id:
1258
+ raise ValueError(
1259
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1260
+ "(You should just add it to the generation config)"
1261
+ )
1262
+
1263
+ return generation_config.lang_to_id[language_token]
1264
+
1265
+ task = getattr(generation_config, "task", None)
1266
+ language = getattr(generation_config, "language", None)
1267
+
1268
+ forced_decoder_ids = generation_config.forced_decoder_ids
1269
+ if forced_decoder_ids is not None:
1270
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1271
+ logger.warning_once(
1272
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1273
+ "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1274
+ )
1275
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1276
+ forced_decoder_ids = config.forced_decoder_ids
1277
+
1278
+ elif forced_decoder_ids is not None and language is not None:
1279
+ logger.info(
1280
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1281
+ )
1282
+ forced_decoder_ids = None
1283
+
1284
+ init_tokens = [generation_config.decoder_start_token_id]
1285
+
1286
+ # Update init_tokens with languages
1287
+ lang_ids = None
1288
+
1289
+ if forced_decoder_ids is not None:
1290
+ return forced_decoder_ids
1291
+
1292
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1293
+ generation_config.forced_decoder_ids = None
1294
+
1295
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1296
+
1297
+ # Make sure language is a list of strings of the correct length
1298
+ if isinstance(language, (list, tuple)):
1299
+ if any(l is None for l in language):
1300
+ raise TypeError(
1301
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1302
+ )
1303
+ if len(language) != batch_size:
1304
+ raise ValueError(
1305
+ "When passing a list of languages, the length of the list must match the batch size. "
1306
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1307
+ )
1308
+ languages = language
1309
+ elif language is None:
1310
+ # Language will be detected for each item in batch
1311
+ languages = [None] * batch_size
1312
+ else:
1313
+ languages = [language] # Use a length-1 list now, broadcast later
1314
+
1315
+ # Separate init_tokens for each language
1316
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1317
+
1318
+ if language is not None and lang_ids is not None:
1319
+ lang_ids = [language_to_id(l) for l in languages]
1320
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1321
+ # language is not defined or intentionally set to `None` to trigger language detection
+ lang_ids = self.detect_language(
1323
+ input_features=input_features,
1324
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1325
+ generation_config=generation_config,
1326
+ num_segment_frames=num_segment_frames,
1327
+ ).tolist()
1328
+ if lang_ids is not None:
1329
+ # append or replace lang_ids to init_tokens
1330
+ for i in range(len(init_tokens)):
1331
+ if len(init_tokens[i]) > 1:
1332
+ init_tokens[i][1] = lang_ids[i]
1333
+ else:
1334
+ init_tokens[i].append(lang_ids[i])
1335
+ del languages
1336
+
1337
+ # Update init_tokens with task
1338
+ for i in range(len(init_tokens)):
1339
+ if task is not None:
1340
+ if task in TASK_IDS:
1341
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1342
+ task_id = generation_config.task_to_id[generation_config.task]
1343
+
1344
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1345
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1346
+ else:
1347
+ raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
1348
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1349
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1350
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1351
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1352
+
1353
+ # let's make sure we don't pass `None` tokens as prompt tokens
1354
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1355
+
1356
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1357
+
1358
+ def detect_language(
1359
+ self,
1360
+ input_features: Optional[torch.FloatTensor] = None,
1361
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1362
+ generation_config: Optional[GenerationConfig] = None,
1363
+ num_segment_frames: int = 3000,
1364
+ ) -> torch.Tensor:
1365
+ """
1366
+ Detects language from log-mel input features or encoder_outputs
1367
+
1368
+ Parameters:
1369
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1370
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1371
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1372
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1373
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1374
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1375
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1376
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1377
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
1378
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1379
+ generation_config (`~generation.GenerationConfig`, *optional*):
1380
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1381
+ passed to generate matching the attributes of `generation_config` will override them. If
1382
+ `generation_config` is not provided, the default will be used, which had the following loading
1383
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1384
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1385
+ default values, whose documentation should be checked to parameterize generation.
1386
+ num_segment_frames (`int`, defaults to 3000):
1387
+ The number of log-mel frames the model expects
1388
+
1389
+ Return:
1390
+ A `torch.LongTensor` representing the detected language ids.
1391
+ """
1392
+ if input_features is None and encoder_outputs is None:
1393
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1394
+ elif input_features is not None and encoder_outputs is not None:
1395
+ raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
1396
+ elif input_features is not None:
1397
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1398
+ batch_size = input_features.shape[0]
1399
+ elif encoder_outputs is not None:
1400
+ inputs = {"encoder_outputs": encoder_outputs}
1401
+ batch_size = (
1402
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1403
+ )
1404
+
1405
+ generation_config = generation_config or self.generation_config
1406
+ decoder_input_ids = (
1407
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1408
+ * generation_config.decoder_start_token_id
1409
+ )
1410
+
1411
+ with torch.no_grad():
1412
+ stno_mask = (
+ self.stno_mask_seek
+ if self.stno_mask_seek is not None
+ else self.stno_mask[:, :, :num_segment_frames // 2]
+ )
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids, stno_mask=stno_mask).logits[:, -1]
+
1417
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1418
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1419
+
1420
+ logits[:, non_lang_mask] = -np.inf
1421
+
1422
+ lang_ids = logits.argmax(-1)
1423
+
1424
+ return lang_ids
1425
+
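+ # Illustrative use (a sketch; it assumes `self.stno_mask` was already set by `generate`
+ # and that the standard Whisper `lang_to_id` mapping is present on the generation config):
+ #
+ # lang_ids = model.detect_language(input_features=feats)
+ # id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
+ # print([id_to_lang[i] for i in lang_ids.tolist()]) # e.g. ['<|en|>', '<|de|>']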
1426
+ def _get_logits_processor(
1427
+ self,
1428
+ generation_config: GenerationConfig,
1429
+ input_ids_seq_length: int,
1430
+ encoder_input_ids: torch.LongTensor,
1431
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
1432
+ logits_processor: Optional[LogitsProcessorList],
1433
+ device: str = None,
1434
+ model_kwargs: Optional[Dict[str, Any]] = None,
1435
+ negative_prompt_ids: Optional[torch.Tensor] = None,
1436
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
1437
+ ) -> LogitsProcessorList:
1438
+ # pylint: disable=no-member
1439
+ gen_config_copy = copy.deepcopy(generation_config)
1440
+ gen_config_copy.forced_decoder_ids = None
1441
+ processors = super()._get_logits_processor(
1442
+ gen_config_copy,
1443
+ input_ids_seq_length,
1444
+ encoder_input_ids,
1445
+ prefix_allowed_tokens_fn,
1446
+ logits_processor,
1447
+ device,
1448
+ model_kwargs,
1449
+ negative_prompt_ids,
1450
+ negative_prompt_attention_mask,
1451
+ )
1452
+ if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
1453
+ enc_logits = self.encoder_logits
1454
+ if generation_config.num_beams <= 1:
1455
+ processors.append(LogSoftmaxProcessor())
1456
+ else:
1457
+ enc_logits = enc_logits.repeat_interleave(generation_config.num_beams, dim=0)
1458
+ self.ctc_rescorer = CTCRescorerLogitsProcessor(
1459
+ enc_logits,
1460
+ torch.full((enc_logits.shape[0],), fill_value=enc_logits.shape[1],
1461
+ device=enc_logits.device),
1462
+ enc_logits.shape[-1] - 1,
1463
+ generation_config.pad_token_id.item(),
1464
+ generation_config.eos_token_id.item(),
1465
+ generation_config.decoder_start_token_id.item(),
1466
+ self.tokenizer,
1467
+ generation_config.ctc_margin,
1468
+ generation_config.ctc_weight,
1469
+ generation_config.num_beams,
1470
+ False,
1471
+ )
1472
+ processors.append(self.ctc_rescorer)
1473
+ return processors
1474
+
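+ # Joint CTC decoding sketch (illustrative; assumes this checkpoint's generation config
+ # defines `ctc_weight`/`ctc_margin` as in the accompanying generation_config.json):
+ #
+ # out = model.generate(input_features=feats, stno_mask=stno_mask, num_beams=5)
+ # # with ctc_weight > 0, the encoder CTC head's log-probabilities are folded into the
+ # # decoder scores via the CTCRescorerLogitsProcessor appended in _get_logits_processor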
1475
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device):
1476
+ if generation_config.return_timestamps is True:
1477
+ timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
1478
+ logits_processor = (
1479
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1480
+ )
1481
+
1482
+ if generation_config.suppress_tokens is not None:
1483
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1484
+ logits_processor = (
1485
+ [suppress_tokens_processor]
1486
+ if logits_processor is None
1487
+ else [suppress_tokens_processor] + logits_processor
1488
+ )
1489
+ generation_config.suppress_tokens = None
1490
+
1491
+ if generation_config.begin_suppress_tokens is not None:
1492
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1493
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1494
+ )
1495
+ logits_processor = (
1496
+ [begin_suppress_processor]
1497
+ if logits_processor is None
1498
+ else [begin_suppress_processor] + logits_processor
1499
+ )
1500
+ generation_config.begin_suppress_tokens = None
1501
+
1502
+ if generation_config.no_speech_threshold is not None and not is_shortform:
1503
+ no_speech_detector = WhisperNoSpeechDetection(
1504
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1505
+ begin_index=begin_index,
1506
+ scores_is_logprobs=num_beams > 1,
1507
+ )
1508
+ logits_processor = (
1509
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1510
+ )
1511
+ no_speech_detector.set_model(self)
1512
+
1513
+ return logits_processor
1514
+
1515
+ @staticmethod
1516
+ def round_to_nearest_0_02(x):
1517
+ d = Decimal(str(x)) # Use str(x) to preserve input precision
1518
+ step = Decimal('0.02')
1519
+ # Divide, round, multiply back
1520
+ rounded = (d / step).to_integral_value(rounding=ROUND_HALF_UP) * step
1521
+ return rounded
1522
+
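
A quick standalone check of the 0.02 s quantization performed by `round_to_nearest_0_02` (re-implemented here for illustration; the input values are arbitrary):

```python
from decimal import Decimal, ROUND_HALF_UP

def round_to_nearest_0_02(x):
    step = Decimal("0.02")
    return (Decimal(str(x)) / step).to_integral_value(rounding=ROUND_HALF_UP) * step

print(round_to_nearest_0_02(1.234))   # 1.24
print(round_to_nearest_0_02(29.999))  # 30.00
print(round_to_nearest_0_02(0.03))    # 0.04 (ties round half up)
```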
1523
+ def _fix_timestamps_from_segmentation(self, sequences):
1524
+ """
1525
+ Adjusts token sequences with global timestamps to fit within Whisper's 0–30s timestamp token range.
1526
+
1527
+ This function modifies the input sequences by inserting appropriate timestamp tokens and
1528
+ offset corrections to ensure the decoded token order is correct, without splitting any segment.
1529
+ It aligns all timestamps to 0.02-second precision, inserts placeholder segments to bridge
1530
+ time gaps between 30-second windows, and maintains segment continuity during encoding.
1531
+
1532
+ Args:
1533
+ sequences (dict): A dictionary containing:
1534
+ - 'segments': A list of segment lists, each segment being a dict with 'start', 'end', and 'tokens'.
1535
+ - 'sequences': A tensor used to determine device for padding.
1536
+
1537
+ Returns:
1538
+ torch.Tensor: A batch of padded token sequences with corrected timestamp alignment.
1539
+ """
1540
+ # Get the token ID for the "<|0.00|>" timestamp used to detect dummy segments
1541
+ first_timestamp_token = self.tokenizer.get_vocab()["<|0.00|>"]
1542
+ results = []
1543
+
1544
+ # Filter out segments that are either empty or consist only of the "<|0.00|>" token
1545
+ for idx, sequence_segs in enumerate(sequences['segments']):
1546
+ sequences['segments'][idx] = [
1547
+ seg for seg in sequence_segs
1548
+ if len(seg['tokens']) > 0 and (len(seg['tokens']) != 1 or seg['tokens'][0] != first_timestamp_token)
1549
+ ]
1550
+
1551
+ # Iterate over each group of segments (e.g., one per utterance)
1552
+ for idx, sequence_segs in enumerate(sequences['segments']):
1553
+ result = []
1554
+ prev_segment_end_time = None
1555
+ correction = Decimal(0.0)
1556
+
1557
+ for i, seg in enumerate(sequence_segs):
1558
+ # Round start and end times to nearest 0.02 seconds
1559
+ start_time = self.round_to_nearest_0_02(seg['start'].item())
1560
+ end_time = self.round_to_nearest_0_02(seg['end'].item())
1561
+ tokens = seg['tokens']
1562
+
1563
+ # Determine which 30s window this segment falls into
1564
+ current_block = (start_time + correction) // 30
1565
+
1566
+ if prev_segment_end_time is not None:
1567
+ # If not the first segment, calculate difference in 30s windows
1568
+ prev_block = prev_segment_end_time // 30
1569
+ num_dummies = current_block - prev_block - 1
1570
+
1571
+ # Insert (30, [], 30) marker if we're moving to a new block
1572
+ if current_block > prev_block:
1573
+ result.append((30, [], 30))
1574
+
1575
+ # Insert dummy segments to bridge skipped 30s blocks
1576
+ for _ in range(int(num_dummies)):
1577
+ result.append((0, [], 30))
1578
+ else:
1579
+ # For the first segment, add dummy blocks if it starts after 30s
1580
+ for _ in range(int(start_time // 30)):
1581
+ result.append((0, [], 30))
1582
+
1583
+ # Determine whether segment fits in one block or wraps to the next
1584
+ if (start_time + correction) // 30 == (end_time + correction) // 30:
1585
+ # Segment fits within a single 30s window
1586
+ result.append(((start_time + correction) % 30, tokens, (end_time + correction) % 30))
1587
+ else:
1588
+ # Segment would wrap across a 30s boundary
1589
+ new_seg_start = (correction + start_time) % 30
1590
+ new_seg_end = end_time - start_time
1591
+
1592
+ if new_seg_end >= new_seg_start:
1593
+ # Seek back to the beginning of the segment window
1594
+ result.append((new_seg_start, [], new_seg_start))
1595
+ result.append((0, tokens, new_seg_end))
1596
+ # Apply correction to align future timestamps to new 30s block
1597
+ correction = self.round_to_nearest_0_02(-(start_time % 30))
1598
+ else:
1599
+ # Otherwise, just insert with adjusted times
1600
+ result.append((new_seg_start, tokens, new_seg_end))
1601
+ correction = self.round_to_nearest_0_02(30 - (start_time % 30))
1602
+ # print(f'Processed segment {i}, result: {self.tokenizer.decode(self.tokenizer("".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result]))["input_ids"], decode_with_timestamps=True)[-250:]}')
1603
+ # Update the previous segment's end time for next iteration
1604
+ prev_segment_end_time = end_time + correction
1605
+
1606
+ # Convert result segments into a token sequence with proper timestamp formatting
1607
+ encoded = self.tokenizer(
1608
+ "".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result])
1609
+ )['input_ids']
1610
+ results.append(encoded)
1611
+
1612
+ # Pad all sequences to the same length for batching
1613
+ sequences = pad_sequence(
1614
+ [torch.tensor(res, device=sequences['sequences'].device) for res in results],
1615
+ batch_first=True,
1616
+ padding_value=self.tokenizer.pad_token_id
1617
+ )
1618
+ return sequences
1619
+
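
A much-simplified toy version of the idea in `_fix_timestamps_from_segmentation`: segments carrying global times are mapped onto Whisper's 0-30 s timestamp range, with empty filler segments bridging skipped 30 s windows. The real method above additionally handles the running correction term and segments that straddle a window boundary, which this sketch omits:

```python
def to_windows(segments):
    """segments: list of (start_s, text, end_s) with global, non-overlapping times."""
    out, prev_block = [], 0
    for start, text, end in segments:
        block = int(start // 30)
        if block > prev_block:
            if out:
                out.append((30.0, "", 30.0))  # close the previous 30 s window
            # bridge any fully skipped windows with empty 0.00 -> 30.00 segments
            out.extend([(0.0, "", 30.0)] * (block - prev_block - (1 if out else 0)))
        out.append((start % 30, text, end % 30))
        prev_block = block
    return out

print(to_windows([(2.0, "hello", 4.5), (65.0, "world", 68.0)]))
# [(2.0, 'hello', 4.5), (30.0, '', 30.0), (0.0, '', 30.0), (5.0, 'world', 8.0)]
```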
1620
+ @staticmethod
1621
+ def _retrieve_segment(
1622
+ seek_sequence,
1623
+ seek_outputs,
1624
+ time_offset,
1625
+ timestamp_begin,
1626
+ seek_num_frames,
1627
+ time_precision,
1628
+ input_stride,
1629
+ prev_idx,
1630
+ idx,
1631
+ return_token_timestamps,
1632
+ ):
1633
+ # find Whisper's "end of segment" predictions
+ # an "end of segment" is predicted whenever Whisper emits a timestamp token
1635
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1636
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1637
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1638
+ timestamp_segment_indices.add_(1)
1639
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1640
+
1641
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
1642
+ # "end of segment" prediction and slice the decoding into segments accordingly
1643
+ if len(timestamp_segment_indices) > 0:
1644
+ # if the output contains two consecutive timestamp tokens
1645
+ slices = timestamp_segment_indices.tolist()
1646
+ segments = []
1647
+ if single_timestamp_ending:
1648
+ slices.append(len(seek_sequence))
1649
+
1650
+ last_slice = 0
1651
+ # Add each segment to list of all segments
1652
+ for current_slice in slices:
1653
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1654
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
1655
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
1656
+ segments.append(
1657
+ {
1658
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1659
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
1660
+ "tokens": sliced_tokens,
1661
+ "result": seek_outputs[idx],
1662
+ }
1663
+ )
1664
+ if return_token_timestamps:
1665
+ segments[-1]["token_timestamps"] = (
1666
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1667
+ )
1668
+ last_slice = current_slice
1669
+
1670
+ if single_timestamp_ending:
1671
+ # single timestamp at the end means no speech after the last timestamp.
1672
+ segment_offset = seek_num_frames[prev_idx]
1673
+ else:
1674
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1675
+ # here we throw away all predictions after the last predicted "end of segment"
1676
+ # since we are cutting right in the middle of an audio
1677
+ last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
1678
+ segment_offset = last_timestamp_pos * input_stride
1679
+ else:
1680
+ # If whisper does not predict any "end of segment" token, then
1681
+ # the whole decoding is considered a segment and we add it to the list of segments
1682
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1683
+ start_timestamp_pos = 0.0
1684
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
1685
+ skip = False
1686
+ segment_offset = seek_num_frames[prev_idx]
1687
+
1688
+ if timestamps.numel() > 1:
1689
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
1690
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
1691
+ elif timestamps.numel() == 1:
1692
+ # no consecutive timestamps but it has a timestamp; use the last one.
1693
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
1694
+ if start_timestamp_pos > 200:
1695
+ # segment does not fit into the decoding window, so we need to roll back
1696
+ segment_offset = start_timestamp_pos * input_stride - 100 # timestamp might be inaccurate
1697
+ skip = True
1698
+ else:
1699
+ # empty sequence, or sequence w/o timestamps
1700
+ skip = True
1701
+
1702
+ if skip:
1703
+ segments = []
1704
+ else:
1705
+ segments = [
1706
+ {
1707
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1708
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1709
+ "tokens": seek_sequence,
1710
+ "result": seek_outputs[idx],
1711
+ }
1712
+ ]
1713
+ if return_token_timestamps:
1714
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1715
+ segment_offset = seek_num_frames[prev_idx]
1716
+
1717
+ if segment_offset <= 0:
1718
+ msg = f"Timestamps: {timestamps}, Segments: {segments}"
1719
+ raise ValueError(f"Segment offset: {segment_offset} <= 0. This should not happen!\n{msg}")
1720
+
1721
+ return segments, segment_offset
1722
+
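
The segment slicing above hinges on one observation: a segment boundary is wherever two timestamp tokens appear back to back. A toy illustration with invented token IDs (in Whisper, every ID at or above `timestamp_begin` is a timestamp):

```python
import torch

timestamp_begin = 100                              # pretend IDs >= 100 are timestamps
seq = torch.tensor([100, 7, 8, 105, 105, 9, 110])  # <|0.00|> text <|0.10|><|0.10|> text <|0.20|>

is_ts = seq.ge(timestamp_begin)
cuts = (torch.where(is_ts[:-1] & is_ts[1:])[0] + 1).tolist()  # indices right after a pair

last, segments = 0, []
for cut in cuts + [len(seq)]:
    chunk = seq[last:cut]
    segments.append((chunk[0].item() - timestamp_begin,   # start offset (0.02 s units)
                     chunk[-1].item() - timestamp_begin,  # end offset
                     chunk[1:-1].tolist()))               # text tokens in between
    last = cut

print(segments)  # [(0, 5, [7, 8]), (5, 10, [9])]
```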
1723
+ def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
1724
+ # remove all previously passed decoder input ids
1725
+ if isinstance(seek_outputs, torch.Tensor):
1726
+ seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
1727
+ seek_outputs = torch.hstack((
1728
+ seek_outputs,
1729
+ torch.full((seek_outputs.shape[0], 1),
1730
+ fill_value=generation_config.pad_token_id,
1731
+ dtype=seek_outputs.dtype,
1732
+ device=seek_outputs.device
1733
+ )
1734
+ ))
1735
+ # first_eos = (seek_outputs == generation_config.eos_token_id).int().argmax(dim=1)
1736
+ # biggest_timestamp = generation_config.no_timestamps_token_id + 1 + 30 * 50
1737
+
1738
+ # empty_transcriptions = first_eos == 0
1739
+ # seek_outputs[empty_transcriptions, 0] = generation_config.no_timestamps_token_id + 1 # 0.00 timestamp
1740
+ # seek_outputs[empty_transcriptions, 1] = biggest_timestamp # 30.00 timestamp
1741
+ # seek_outputs[empty_transcriptions, 2] = generation_config.eos_token_id # 30.00 timestamp
1742
+
1743
+ return seek_outputs, seek_outputs
1744
+
1745
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
1746
+ num_frames = getattr(generation_config, "num_frames", None)
1747
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
1748
+ seek_outputs, generation_config.alignment_heads, num_frames=num_frames
1749
+ )
1750
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1]:]
1751
+
1752
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1]:]
1753
+
1754
+ def split_by_batch_index(values, key, batch_idx):
1755
+ if key == "scores":
1756
+ return [v[batch_idx].cpu() for v in values]
1757
+ elif key == "past_key_values":
1758
+ # we don't save `past_key_values` as this is too costly
1759
+ return None
1760
+ elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
1761
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
1762
+ return values[batch_idx].cpu()
1763
+
1764
+ sequence_tokens = seek_outputs["sequences"]
1765
+ seek_outputs = [
1766
+ {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
1767
+ for i in range(sequence_tokens.shape[0])
1768
+ ]
1769
+
1770
+ return sequence_tokens, seek_outputs
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "begin_suppress_tokens": [
4
+ 220,
5
+ 50256
6
+ ],
7
+ "bos_token_id": 50257,
8
+ "decoder_start_token_id": 50258,
9
+ "eos_token_id": 50257,
10
+ "pad_token_id": 50257,
11
+ "transformers_version": "4.42.0"
12
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc3ff21a41ebdb9dbe637815740c4edcf77bfbfe962c601ca33071340fd77bd9
3
+ size 3833628952
modeling_dicow.py ADDED
@@ -0,0 +1,362 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import CrossEntropyLoss
6
+ import torch.utils.checkpoint
8
+ from transformers.modeling_outputs import Seq2SeqLMOutput
9
+ from transformers.models.whisper.modeling_whisper import (
+ WhisperEncoder,
+ WhisperForConditionalGeneration,
+ WhisperModel,
+ shift_tokens_right,
+ sinusoids,
+ )
21
+ from transformers.utils import logging
22
+
23
+ from .config import Seq2SeqLMOutputLosses, Seq2SeqModelOutputLogit, DiCoWConfig
24
+ from .encoder import CustomLinear, CustomDiagonalLinear, FDDT, DiCoWEncoder
25
+ from .generation import DiCoWGenerationMixin
26
+
27
+ logging.set_verbosity_debug()
28
+ logger = logging.get_logger("transformers")
29
+
30
+
31
+ class DiCoW(WhisperModel):
32
+ def __init__(self, config: DiCoWConfig):
33
+ super().__init__(config)
34
+ self.encoder = DiCoWEncoder(config)
35
+
36
+ def forward(
37
+ self,
38
+ input_features: Optional[torch.FloatTensor] = None,
39
+ attention_mask: Optional[torch.LongTensor] = None,
40
+ decoder_input_ids: Optional[torch.LongTensor] = None,
41
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
42
+ head_mask: Optional[torch.Tensor] = None,
43
+ decoder_head_mask: Optional[torch.Tensor] = None,
44
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
45
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
46
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
47
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
48
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
49
+ use_cache: Optional[bool] = None,
50
+ output_attentions: Optional[bool] = None,
51
+ output_hidden_states: Optional[bool] = None,
52
+ return_dict: Optional[bool] = None,
53
+ stno_mask: Optional[torch.FloatTensor] = None,
54
+ per_group_sizes: Optional[torch.LongTensor] = None,
55
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutputLosses]:
56
+ r"""
57
+ Returns:
58
+
59
+ Example:
60
+ ```python
61
+ >>> import torch
62
+ >>> from transformers import AutoFeatureExtractor, WhisperModel
63
+ >>> from datasets import load_dataset
64
+
65
+ >>> model = WhisperModel.from_pretrained("openai/whisper-base")
66
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
67
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
68
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
69
+ >>> input_features = inputs.input_features
70
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
71
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
72
+ >>> list(last_hidden_state.shape)
73
+ [1, 2, 512]
74
+ ```"""
75
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
76
+ output_hidden_states = (
77
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
78
+ )
79
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
80
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
81
+
82
+ if encoder_outputs is None:
83
+ input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
84
+
85
+ encoder_outputs = self.encoder(
86
+ input_features,
87
+ output_attentions=output_attentions,
88
+ output_hidden_states=True,
89
+ head_mask=head_mask,
90
+ return_dict=return_dict,
91
+ stno_mask=stno_mask,
92
+ per_group_sizes=per_group_sizes
93
+ )
94
+ # Upstream Whisper validates that a user-passed encoder_outputs is a BaseModelOutput; that check is disabled here:
95
+ # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
96
+ # raise ValueError("encoder_outputs should be of type BaseModelOutput when return_dict=True.")
97
+
98
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
99
+ decoder_outputs = self.decoder(
100
+ input_ids=decoder_input_ids,
101
+ attention_mask=decoder_attention_mask,
102
+ encoder_hidden_states=encoder_outputs.hidden_states[-1],
103
+ head_mask=decoder_head_mask,
104
+ cross_attn_head_mask=cross_attn_head_mask,
105
+ past_key_values=past_key_values,
106
+ inputs_embeds=decoder_inputs_embeds,
107
+ position_ids=decoder_position_ids,
108
+ use_cache=use_cache,
109
+ output_attentions=output_attentions,
110
+ output_hidden_states=output_hidden_states,
111
+ return_dict=return_dict,
112
+ )
113
+
114
+ if not return_dict:
115
+ return decoder_outputs + encoder_outputs
116
+
117
+ return Seq2SeqModelOutputLogit(
118
+ last_hidden_state=decoder_outputs.last_hidden_state,
119
+ past_key_values=decoder_outputs.past_key_values,
120
+ decoder_hidden_states=decoder_outputs.hidden_states,
121
+ decoder_attentions=decoder_outputs.attentions,
122
+ cross_attentions=decoder_outputs.cross_attentions,
123
+ encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
124
+ encoder_hidden_states=encoder_outputs.hidden_states,
125
+ encoder_attentions=encoder_outputs.attentions,
126
+ encoder_logits=encoder_outputs.logits,
127
+ )
128
+
129
+
130
+ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
131
+ config_class = DiCoWConfig
132
+
133
+ def __init__(self, config: DiCoWConfig):
134
+ super().__init__(config)
135
+ self.model = DiCoW(config)
136
+ self.encoder_logits = None
137
+ self.tokenizer = None
138
+ self.vad_seek_callback = None
139
+ self.stno_mask = None
140
+ self.stno_mask_seek = None
141
+
142
+ # We need this setter as we can't pass a function/method as a config argument.
143
+ # JSON serialization fails at that point.
144
+ def set_vad_seek_callback(self, vad_seek_callback):
145
+ self.vad_seek_callback = vad_seek_callback
146
+
147
+ def set_tokenizer(self, tokenizer):
148
+ self.tokenizer = tokenizer
149
+
150
+ def _init_weights(self, module):
151
+ std = self.config.init_std
152
+ fddt_init = self.config.fddt_init
153
+ if isinstance(module, CustomLinear):
154
+ with torch.no_grad():
155
+ if fddt_init == 'random':
156
+ module.weight.data.normal_(mean=0.0, std=std)
157
+ if module.bias is not None:
158
+ module.bias.data.normal_(mean=0.0, std=std)
159
+ elif fddt_init == 'non-disturbing':
160
+ module.weight.data = torch.eye(*module.weight.shape).data
161
+ if module.bias is not None:
162
+ module.bias.data.zero_()
163
+ elif fddt_init == 'disparagement':
164
+ eye = torch.eye(*module.weight.shape)
165
+ eye *= module.init_eye_val
166
+ module.weight.data = eye.data
167
+ if module.bias is not None:
168
+ module.bias.data.zero_()
169
+ elif isinstance(module, CustomDiagonalLinear):
170
+ with torch.no_grad():
171
+ if fddt_init == 'random':
172
+ module.weight.data.normal_(mean=0.0, std=std)
173
+ if module.bias is not None:
174
+ module.bias.data.normal_(mean=0.0, std=std)
175
+ elif fddt_init == 'non-disturbing':
176
+ module.weight.data = torch.ones_like(module.weight.data).data
177
+ if module.bias is not None:
178
+ module.bias.data.zero_()
179
+ elif fddt_init == 'disparagement':
180
+ module.weight.data = module.init_eye_val * torch.ones_like(module.weight.data).data
181
+ if module.bias is not None:
182
+ module.bias.data.zero_()
183
+ elif isinstance(module, FDDT):
184
+ if module.bias_only:
185
+ if fddt_init == 'random':
186
+ module.target_linear.data.normal_(mean=0.0, std=std)
187
+ module.non_target_linear.data.normal_(mean=0.0, std=std)
188
+ module.overlap_linear.data.normal_(mean=0.0, std=std)
189
+ module.silence_linear.data.normal_(mean=0.0, std=std)
190
+ else:
191
+ module.target_linear.data.zero_()
192
+ module.non_target_linear.data.zero_()
193
+ module.overlap_linear.data.zero_()
194
+ module.silence_linear.data.zero_()
195
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
196
+ module.weight.data.normal_(mean=0.0, std=std)
197
+ if module.bias is not None:
198
+ module.bias.data.zero_()
199
+ elif isinstance(module, nn.Embedding):
200
+ module.weight.data.normal_(mean=0.0, std=std)
201
+ if module.padding_idx is not None:
202
+ module.weight.data[module.padding_idx].zero_()
203
+ elif isinstance(module, WhisperEncoder):
204
+ with torch.no_grad():
205
+ embed_positions = module.embed_positions.weight
206
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
207
+ elif isinstance(module, nn.LayerNorm):
208
+ module.reset_parameters()
209
+ elif isinstance(module, nn.MultiheadAttention):
210
+ module._reset_parameters()
211
+ elif isinstance(module, nn.ConvTranspose1d):
212
+ module.reset_parameters()
213
+
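
The `fddt_init` modes above control how much the FDDT layers perturb the pretrained encoder at the start of fine-tuning. The 'non-disturbing' case boils down to an identity initialization, sketched here on a plain square linear layer (shapes are arbitrary):

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=True)
with torch.no_grad():
    layer.weight.copy_(torch.eye(8))  # identity weight
    layer.bias.zero_()                # zero bias

x = torch.randn(2, 8)
assert torch.allclose(layer(x), x)    # the layer is a no-op before any training
```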
214
+ def forward(
215
+ self,
216
+ input_features: Optional[torch.FloatTensor] = None,
217
+ stno_mask: Optional[torch.FloatTensor] = None,
218
+ per_group_sizes: Optional[torch.LongTensor] = None,
219
+ attention_mask_enc: Optional[torch.LongTensor] = None,
220
+ attention_mask: Optional[torch.LongTensor] = None,
221
+ decoder_input_ids: Optional[torch.LongTensor] = None,
222
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
223
+ head_mask: Optional[torch.Tensor] = None,
224
+ decoder_head_mask: Optional[torch.Tensor] = None,
225
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
226
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
227
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
228
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
229
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
230
+ labels: Optional[torch.LongTensor] = None,
231
+ upp_labels: Optional[torch.LongTensor] = None,
232
+ use_cache: Optional[bool] = None,
233
+ output_attentions: Optional[bool] = None,
234
+ output_hidden_states: Optional[bool] = None,
235
+ return_dict: Optional[bool] = None,
236
+ is_valid: Optional[bool] = None,
237
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
238
+ r"""
239
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
240
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
241
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
242
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
243
+
244
+ Returns:
245
+
246
+ Example:
247
+
248
+ ```python
249
+ >>> import torch
250
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
251
+ >>> from datasets import load_dataset
252
+
253
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
254
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
255
+
256
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
257
+
258
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
259
+ >>> input_features = inputs.input_features
260
+
261
+ >>> generated_ids = model.generate(inputs=input_features)
262
+
263
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
264
+ >>> transcription
265
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
266
+ ```"""
267
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
268
+
269
+ if labels is not None:
270
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
271
+ decoder_input_ids = shift_tokens_right(
272
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
273
+ )
274
+
275
+ outputs = self.model(
276
+ input_features,
277
+ attention_mask=attention_mask,
278
+ decoder_input_ids=decoder_input_ids,
279
+ encoder_outputs=encoder_outputs,
280
+ decoder_attention_mask=decoder_attention_mask,
281
+ head_mask=head_mask,
282
+ decoder_head_mask=decoder_head_mask,
283
+ cross_attn_head_mask=cross_attn_head_mask,
284
+ past_key_values=past_key_values,
285
+ decoder_inputs_embeds=decoder_inputs_embeds,
286
+ decoder_position_ids=decoder_position_ids,
287
+ use_cache=use_cache,
288
+ output_attentions=output_attentions,
289
+ output_hidden_states=output_hidden_states,
290
+ return_dict=return_dict,
291
+ stno_mask=stno_mask,
292
+ per_group_sizes=per_group_sizes
293
+ )
294
+
295
+ dec_lm_logits = self.proj_out(outputs.last_hidden_state)
296
+ enc_lm_logits = outputs.encoder_logits
297
+
298
+ loss = None
299
+ ctc_loss = 0
300
+
301
+ # remove fake inputs from labels and logits given per group sizes
302
+ if is_valid is not None:
303
+ if self.config.ctc_weight > 0.0:
304
+ enc_lm_logits = enc_lm_logits[is_valid]
305
+ dec_lm_logits = dec_lm_logits[is_valid]
306
+ labels = labels[is_valid]
307
+ upp_labels = upp_labels[is_valid]
308
+
309
+ if labels is not None and self.config.ctc_weight > 0.0:
310
+ enc_labels = labels.clone()
311
+ for token in self.tokenizer.prefix_tokens:
312
+ if (enc_labels[:, 0] == token).all():
313
+ enc_labels = enc_labels[:, 1:]
314
+ enc_labels[enc_labels == self.config.eos_token_id] = -100
315
+
316
+ ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
317
+
318
+ if labels is not None:
319
+ loss_fct = CrossEntropyLoss(reduction='none')
320
+ # move labels to correct device to enable PP
321
+ labels = labels.to(dec_lm_logits.device)
322
+ dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
323
+ dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
324
+ dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
325
+ loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
326
+
327
+ if not return_dict:
328
+ output = (dec_lm_logits,) + outputs[1:]
329
+ return ((loss,) + output) if loss is not None else output
330
+
331
+ return Seq2SeqLMOutputLosses(
332
+ loss=loss,
333
+ logits=dec_lm_logits,
334
+ past_key_values=outputs.past_key_values,
335
+ decoder_hidden_states=outputs.decoder_hidden_states,
336
+ decoder_attentions=outputs.decoder_attentions,
337
+ cross_attentions=outputs.cross_attentions,
338
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
339
+ encoder_hidden_states=outputs.encoder_hidden_states,
340
+ encoder_attentions=outputs.encoder_attentions,
341
+ encoder_logits=enc_lm_logits,
342
+ )
343
+
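
The decoder loss above scores every token against both `labels` and `upp_labels` (an uppercased variant) and keeps the smaller of the two cross-entropies per token. The same computation on toy tensors, using `torch.minimum` in place of the hstack/min combination:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab, T = 10, 6
logits = torch.randn(1, T, vocab)
labels = torch.randint(0, vocab, (1, T))
upp_labels = torch.randint(0, vocab, (1, T))

loss_fct = CrossEntropyLoss(reduction="none")
l1 = loss_fct(logits.view(-1, vocab), labels.view(-1))      # per-token loss vs. labels
l2 = loss_fct(logits.view(-1, vocab), upp_labels.view(-1))  # per-token loss vs. upp_labels
loss = torch.minimum(l1, l2).mean()
print(loss)
```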
344
+ def _get_feat_extract_output_lengths(self, attention_mask: torch.Tensor) -> torch.Tensor:
345
+ return (self.model.encoder._get_feat_extract_output_lengths(attention_mask) / 4).ceil()
346
+
347
+ def freeze_except(self, prefixes_to_preheat):
348
+ for name, param in self.named_parameters():
349
+ param.requires_grad = False
350
+ for prefix in prefixes_to_preheat:
351
+ if name.startswith(prefix):
352
+ param.requires_grad = True
353
+
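
A standalone illustration of what `freeze_except` does, on a toy module: only parameters whose names start with one of the given prefixes stay trainable. The prefixes for the real model depend on its `named_parameters()` and are not shown here:

```python
import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

def freeze_except(module, prefixes_to_preheat):
    for name, param in module.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in prefixes_to_preheat)

freeze_except(toy, ["1."])  # keep only the second Linear trainable
print([(n, p.requires_grad) for n, p in toy.named_parameters()])
# [('0.weight', False), ('0.bias', False), ('1.weight', True), ('1.bias', True)]
```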
354
+ def suppress_interactions(self):
355
+ """This method suppress final projection in CoAttention blocks to let the original information flow through"""
356
+ for name, param in self.named_parameters():
357
+ if "interaction" in name and "cat_proj" in name:
358
+ with torch.no_grad():
359
+ if "bias" in name:
360
+ param[:] = 0.
361
+ else:
362
+ param[:] *= 0.001
utils.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from transformers import WhisperTimeStampLogitsProcessor
5
+
6
+
7
+ def remove_fake_elements(inputs, per_group_sizes):
8
+ max_spks = per_group_sizes.max()
9
+ number_of_groups = per_group_sizes.shape[0]
10
+ outputs = []
11
+ inputs = inputs.view(number_of_groups, max_spks, *inputs.shape[1:])
12
+ for i, group_size in enumerate(per_group_sizes):
13
+ outputs.append(inputs[i, :group_size])
14
+ outputs = torch.cat(outputs, dim=0)
15
+ return outputs
16
+
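
A small shape walk-through of `remove_fake_elements` (assuming this module is importable as `utils`): inputs are padded to the largest number of speakers per group, and the padding slots are dropped.

```python
import torch
from utils import remove_fake_elements  # assumes this file is on the path as utils.py

per_group_sizes = torch.tensor([2, 1])           # two groups, at most 2 speakers each
inputs = torch.arange(4 * 3).view(4, 3).float()  # (groups * max_spks, feat) = (4, 3)

out = remove_fake_elements(inputs, per_group_sizes)
print(out.shape)  # torch.Size([3, 3]): rows 0-1 from group 0, row 2 from group 1
```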
17
+
18
+ class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
19
+ def __init__(
20
+ self, generate_config, begin_index: Optional[int] = None,
21
+ _detect_timestamp_from_logprob: Optional[bool] = None
22
+ ): # support for the kwargs
23
+ self.no_timestamps_token_id = generate_config.no_timestamps_token_id
24
+ self.timestamp_begin = generate_config.no_timestamps_token_id + 1
25
+ self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id
26
+
27
+ # this variable is mostly just used for testing
28
+ self._detect_timestamp_from_logprob = (
29
+ _detect_timestamp_from_logprob
30
+ if _detect_timestamp_from_logprob is not None
31
+ else getattr(generate_config, "_detect_timestamp_from_logprob", True)
32
+ )
33
+
34
+ num_forced_ids = (
35
+ len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
36
+ )
37
+ self.begin_index = begin_index or (num_forced_ids + 1)
38
+
39
+ self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
40
+ self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)
41
+ # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
42
+ # self.max_initial_timestamp_index = 50
43
+
44
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
45
+ # suppress <|notimestamps|> which is handled by without_timestamps
46
+ scores_processed = scores.clone()
47
+ scores_processed[:, self.no_timestamps_token_id] = -float("inf")
48
+
49
+ # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
50
+ for k in range(input_ids.shape[0]):
51
+ sampled_tokens = input_ids[k, self.begin_index:]
52
+ seq = list(sampled_tokens.tolist())
53
+
54
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
55
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
56
+
57
+ if last_was_timestamp:
58
+ if penultimate_was_timestamp: # has to be non-timestamp
59
+ scores_processed[k, self.timestamp_begin:] = -float("inf")
60
+ else: # cannot be normal text tokens
61
+ scores_processed[k, : self.eos_token_id] = -float("inf")
62
+
63
+ timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
64
+ if timestamps.numel() > 0:
65
+ # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
66
+ # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
67
+ if last_was_timestamp and not penultimate_was_timestamp:
68
+ timestamp_last = timestamps[-1]
69
+ else:
70
+ # Avoid emitting <|0.00|> again
71
+ timestamp_last = timestamps[-1] + 1
72
+
73
+ scores_processed[k, self.timestamp_begin: timestamp_last] = -float("inf")
74
+
75
+ # apply the `max_initial_timestamp` option
76
+ if input_ids.shape[1] == self.begin_index:
77
+ eos_scores = scores_processed[:, self.eos_token_id].clone()
78
+ scores_processed[:, : self.timestamp_begin] = -float("inf")
79
+ scores_processed[:, self.eos_token_id] = eos_scores
80
+
81
+ if self.max_initial_timestamp_index is not None:
82
+ last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
83
+ scores_processed[:, last_allowed + 1:] = -float("inf")
84
+ if self.min_initial_timestamp_index is not None:
85
+ first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
86
+ scores_processed[:, self.timestamp_begin:first_allowed] = -float("inf")
87
+
88
+ # if sum of probability over timestamps is above any other token, sample timestamp
89
+ logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
90
+ for k in range(input_ids.shape[0]):
91
+ timestamp_logprob = logprobs[k, self.timestamp_begin:].logsumexp(dim=-1)
92
+ max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
93
+ if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
94
+ scores_processed[k, : self.timestamp_begin] = -float("inf")
95
+
96
+ return scores_processed
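
The initial-timestamp handling at the end of `__call__` can be seen in isolation: at the first generation step only timestamps inside the `[min_initial_timestamp_index, max_initial_timestamp_index]` window (plus EOS) keep finite scores. A standalone sketch with toy token IDs:

```python
import torch

timestamp_begin, eos_token_id = 11, 9
max_initial, min_initial = 5, 2            # window bounds in 0.02 s steps
scores = torch.zeros(1, 20)                # toy vocabulary of 20 tokens

eos = scores[:, eos_token_id].clone()
scores[:, :timestamp_begin] = -float("inf")                               # force a timestamp first ...
scores[:, eos_token_id] = eos                                             # ... but still allow EOS
scores[:, timestamp_begin + max_initial + 1:] = -float("inf")             # upper bound
scores[:, timestamp_begin:timestamp_begin + min_initial] = -float("inf")  # lower bound

print(torch.isfinite(scores[0]).nonzero().flatten().tolist())
# [9, 13, 14, 15, 16]  -> EOS plus <|0.04|> ... <|0.10|>
```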