mrprimenotes committed
Commit 872634a · verified · 1 parent: 7518480

fix generate

Files changed (1):
  1. model.py +126 -0
model.py CHANGED
@@ -2130,6 +2130,132 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 for i in range(len(outputs.encoder_hidden_states))
             )
             return outputs
+
+        def _pad_to_max_length(
+            current_segments,
+            pad_token_id,
+            device,
+            padding_side="right",
+            padding="longest",
+            bos_token_tensor=None,
+            cut_off_length=None,
+            return_token_timestamps=False,
+            force_unique_generate_call=False,
+        ):
+            max_total_length = 0
+            sequences = []
+            token_timestamps_list = []
+
+            if padding_side not in ["right", "left"]:
+                raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+            if padding not in ["longest", "max_length"]:
+                raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+            elif padding == "max_length" and cut_off_length is None:
+                raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
+
+            if force_unique_generate_call:
+                sequences_list = []
+                timestamps_list = []
+                for segments in current_segments:
+                    result = segments[0]["result"]
+                    sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
+                    if return_token_timestamps:
+                        timestamps_list.append(result["token_timestamps"])
+
+                sequences = torch.stack(sequences_list, dim=0)
+                if return_token_timestamps:
+                    token_timestamps = torch.stack(timestamps_list, dim=0)
+                    return sequences, token_timestamps
+                return sequences
+
+            for current_segment_list in current_segments:
+                if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
+                    sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+                    if return_token_timestamps:
+                        token_timestamps = torch.cat(
+                            [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
+                            dim=-1,
+                        )
+
+                    if cut_off_length is not None:
+                        sequence = sequence[-cut_off_length:]
+                        if return_token_timestamps:
+                            token_timestamps = token_timestamps[-cut_off_length:]
+
+                    if bos_token_tensor is not None:
+                        sequence = torch.cat([bos_token_tensor, sequence])
+                        if return_token_timestamps:
+                            token_timestamps = torch.cat(
+                                [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
+                            )
+                    sequences.append(sequence)
+                    if return_token_timestamps:
+                        token_timestamps_list.append(token_timestamps)
+                    max_total_length = max(max_total_length, len(sequences[-1]))
+                elif bos_token_tensor is not None:
+                    sequences.append(bos_token_tensor)
+                    if return_token_timestamps:
+                        token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
+                else:
+                    sequences.append(torch.tensor([], device=device))
+                    if return_token_timestamps:
+                        token_timestamps_list.append(torch.tensor([], device=device))
+
+            max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
+            for i in range(len(current_segments)):
+                pad_length = max_total_length - len(sequences[i])
+                pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
+                sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+                if return_token_timestamps:
+                    token_timestamps_list[i] = F.pad(
+                        token_timestamps_list[i],
+                        pad=pad,
+                        value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
+                    )
+
+            sequences = torch.stack(sequences, dim=0)
+
+            if return_token_timestamps:
+                token_timestamps = torch.stack(token_timestamps_list, dim=0)
+                return sequences, token_timestamps
+            else:
+                return sequences
+
+        padded_outputs = _pad_to_max_length(
+            current_segments=final_segments,
+            pad_token_id=generation_config.pad_token_id,
+            device=self.device,
+            padding_side="right",
+            return_token_timestamps=return_token_timestamps,
+            force_unique_generate_call=force_unique_generate_call,
+        )
+
+        if return_dict_in_generate and generation_config.return_dict_in_generate:
+            logger.warning_once(
+                "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the results of the underlying calls to GenerationMixin's generate in the returned `segments`."
+            )
+            return_segments = True
+        elif not return_segments and not return_token_timestamps:
+            return padded_outputs
+
+        if return_token_timestamps:
+            sequences, token_timestamps = padded_outputs
+            outputs = {
+                "sequences": sequences,
+                "token_timestamps": token_timestamps,
+            }
+        else:
+            sequences = padded_outputs
+            outputs = {
+                "sequences": sequences,
+            }
+
+        if return_segments:
+            outputs["segments"] = final_segments
+
+        return outputs
 
 @add_start_docstrings(
     "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",