mrprimenotes committed · Commit 811659f · verified · 1 Parent(s): 106d647

update generate method of mixin

Files changed (1): model.py (+91 −313)
model.py CHANGED
@@ -1801,17 +1801,61 @@ class WhisperModel(WhisperPreTrainedModel):
             encoder_attentions=encoder_outputs.attentions,
         )
 
-
+def _pad_to_max_length(
+    current_segments,
+    pad_token_id,
+    device,
+    padding_side="right",
+    padding="longest",
+    bos_token_tensor=None,
+    cut_off_length=None,
+):
+    max_total_length = 0
+    sequences = []
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
+
+    for current_segment_list in current_segments:
+        if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
+            sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+
+            if cut_off_length is not None:
+                sequence = sequence[-cut_off_length:]
+
+            if bos_token_tensor is not None:
+                sequence = torch.cat([bos_token_tensor, sequence])
+
+            sequences.append(sequence)
+            max_total_length = max(max_total_length, len(sequences[-1]))
+        elif bos_token_tensor is not None:
+            sequences.append(bos_token_tensor)
+        else:
+            sequences.append(torch.tensor([], device=device))
+
+    max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
+    for i in range(len(current_segments)):
+        pad_length = max_total_length - len(sequences[i])
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+        sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+
+    sequences = torch.stack(sequences, dim=0)
+    return sequences
 
 # CUSTOM (patch the generation method)
 class CustomWhisperGenerationMixin(WhisperGenerationMixin):
     def generate(
         self,
         input_features: Optional[torch.Tensor] = None,
-        generation_config: Optional[Any] = None,
+        generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[Any] = None,
-        prefix_allowed_tokens_fn: Optional[Any] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         synced_gpus: bool = False,
         return_timestamps: Optional[bool] = None,
         task: Optional[str] = None,
@@ -1827,11 +1871,9 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         num_segment_frames: Optional[int] = None,
         attention_mask: Optional[torch.Tensor] = None,
         time_precision: float = 0.02,
-        time_precision_features: float = 0.01,
         return_token_timestamps: Optional[bool] = None,
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
-        force_unique_generate_call: Optional[bool] = None,
         **kwargs,
     ):
         # 0. deprecate old inputs
@@ -1846,7 +1888,7 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
 
         # 2. set global generate variables
-        input_stride = self.model.encoder.get_conv_stride()
+        input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         num_segment_frames = input_stride * self.config.max_source_positions
         batch_size, total_input_frames = self._retrieve_total_input_frames(
             input_features=input_features, input_stride=input_stride, kwargs=kwargs
@@ -1898,21 +1940,11 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         # 3. Retrieve logits processors
         device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
         begin_index = init_tokens.shape[1]
-        num_beams = kwargs.get(
-            "num_beams",
-            generation_config.num_beams
-            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
-            else 1,
-        )
-        if "assistant_model" in kwargs:
-            # speculative decoding: the model should be able to return eos token
-            generation_config.begin_suppress_tokens = None
-
         logits_processor = self._retrieve_logit_processors(
             generation_config=generation_config,
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
-            num_beams=num_beams,
+            num_beams=kwargs.get("num_beams", 1),
             device=device,
         )
 
@@ -1956,19 +1988,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
             batch_size=cur_bsz,
             generation_config=generation_config,
         )
-        # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
-        # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
-        if "assistant_model" in kwargs:
-            assistant_model = kwargs["assistant_model"]
-            assistant_model.generation_config.force_unique_generate_call = True
-
-        if force_unique_generate_call is None:
-            if hasattr(generation_config, "force_unique_generate_call"):
-                force_unique_generate_call = generation_config.force_unique_generate_call
-            elif hasattr(self.generation_config, "force_unique_generate_call"):
-                force_unique_generate_call = self.generation_config.force_unique_generate_call
-            else:
-                force_unique_generate_call = False
 
         # 6 Transcribe audio until we reach the end of all input audios
         while (seek < max_frames).any():
@@ -1983,9 +2002,7 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = (
-                seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
-            )
+            time_offset = seek * time_precision / input_stride
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
 
             # 6.2 cut out next 30s segment from input features
@@ -1997,13 +2014,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-
-            def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
-                if logits_processor is not None:
-                    logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
-                    if logit_processor:
-                        return getattr(logit_processor, attribute_name, None)
-                return None
 
             # 6.3 prepare decoder input ids
             suppress_tokens = _get_attr_from_logit_processors(
@@ -2021,7 +2031,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 config=self.config,
                 device=init_tokens.device,
                 suppress_tokens=suppress_tokens,
-                timestamp_begin=timestamp_begin,
                 kwargs=kwargs,
             )
 
@@ -2082,20 +2091,18 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                     timestamp_begin=timestamp_begin,
                     seek_num_frames=seek_num_frames,
                     time_precision=time_precision,
-                    time_precision_features=time_precision_features,
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
                     return_token_timestamps=return_token_timestamps,
-                    decoder_input_ids=decoder_input_ids,
                 )
 
-                seek[prev_i] += segment_offset
-
                 current_segments[prev_i] += segments
 
-                if force_unique_generate_call:
-                    break
+                if is_shortform:
+                    seek[prev_i] += max_frames[i]
+                else:
+                    seek[prev_i] += segment_offset
 
         # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
         # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -2105,154 +2112,51 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
             else current_segments
         )
 
-        # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
-        # -> we can return a ModelOutput
-        # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
-        if (
-            return_dict_in_generate
-            and generation_config.return_dict_in_generate
-            and (force_unique_generate_call or not return_timestamps)
-        ):
-            # only one call to generate_with_fallback, we can return a ModelOutput
-            outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
-            if num_return_sequences > 1:
-                if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
-                    outputs.encoder_attentions = tuple(
-                        outputs.encoder_attentions[i][::num_return_sequences]
-                        for i in range(len(outputs.encoder_attentions))
-                    )
-                if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
-                    outputs.encoder_hidden_states = tuple(
-                        outputs.encoder_hidden_states[i][::num_return_sequences]
-                        for i in range(len(outputs.encoder_hidden_states))
-                    )
-            return outputs
-
-        def _pad_to_max_length(
-            current_segments,
-            pad_token_id,
-            device,
-            padding_side="right",
-            padding="longest",
-            bos_token_tensor=None,
-            cut_off_length=None,
-            return_token_timestamps=False,
-            force_unique_generate_call=False,
-        ):
-            max_total_length = 0
-            sequences = []
-            token_timestamps_list = []
-
-            if padding_side not in ["right", "left"]:
-                raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
-
-            if padding not in ["longest", "max_length"]:
-                raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
-            elif padding == "max_length" and cut_off_length is None:
-                raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
-
-            if force_unique_generate_call:
-                sequences_list = []
-                timestamps_list = []
-                for segments in current_segments:
-                    result = segments[0]["result"]
-                    sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
-                    if return_token_timestamps:
-                        timestamps_list.append(result["token_timestamps"])
-
-                sequences = torch.stack(sequences_list, dim=0)
-                if return_token_timestamps:
-                    token_timestamps = torch.stack(timestamps_list, dim=0)
-                    return sequences, token_timestamps
-                return sequences
-
-            for current_segment_list in current_segments:
-                if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-                    sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
-                    if return_token_timestamps:
-                        token_timestamps = torch.cat(
-                            [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
-                            dim=-1,
-                        )
-
-                    if cut_off_length is not None:
-                        sequence = sequence[-cut_off_length:]
-                        if return_token_timestamps:
-                            token_timestamps = token_timestamps[-cut_off_length:]
-
-                    if bos_token_tensor is not None:
-                        sequence = torch.cat([bos_token_tensor, sequence])
-                        if return_token_timestamps:
-                            token_timestamps = torch.cat(
-                                [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
-                            )
-                    sequences.append(sequence)
-                    if return_token_timestamps:
-                        token_timestamps_list.append(token_timestamps)
-                    max_total_length = max(max_total_length, len(sequences[-1]))
-                elif bos_token_tensor is not None:
-                    sequences.append(bos_token_tensor)
-                    if return_token_timestamps:
-                        token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
-                else:
-                    sequences.append(torch.tensor([], device=device))
-                    if return_token_timestamps:
-                        token_timestamps_list.append(torch.tensor([], device=device))
-
-            max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
-            for i in range(len(current_segments)):
-                pad_length = max_total_length - len(sequences[i])
-                pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+        sequences = _pad_to_max_length(
+            final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
+        )
 
-                sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
-                if return_token_timestamps:
-                    token_timestamps_list[i] = F.pad(
-                        token_timestamps_list[i],
-                        pad=pad,
-                        value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
-                    )
+        # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
+        if return_segments:
+            return {"sequences": sequences, "segments": final_segments}
 
-            sequences = torch.stack(sequences, dim=0)
+        if is_shortform:
+            # add eos token:
+            if generation_config.max_new_tokens is None and generation_config.max_length is None:
+                eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
+                sequences = torch.cat([sequences, eos_tokens], dim=-1)
 
             if return_token_timestamps:
-                token_timestamps = torch.stack(token_timestamps_list, dim=0)
-                return sequences, token_timestamps
+                outputs = {}
+                outputs["sequences"] = sequences
+                outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
             else:
-                return sequences
+                outputs = sequences
 
-        padded_outputs = _pad_to_max_length(
-            current_segments=final_segments,
-            pad_token_id=generation_config.pad_token_id,
-            device=self.device,
-            padding_side="right",
-            return_token_timestamps=return_token_timestamps,
-            force_unique_generate_call=force_unique_generate_call,
-        )
+            if return_dict_in_generate and generation_config.return_dict_in_generate:
+                dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
 
-        if return_dict_in_generate and generation_config.return_dict_in_generate:
-            logger.warning_once(
-                "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
-            )
-            return_segments = True
-        elif not return_segments and not return_token_timestamps:
-            return padded_outputs
-
-        if return_token_timestamps:
-            sequences, token_timestamps = padded_outputs
-            outputs = {
-                "sequences": sequences,
-                "token_timestamps": token_timestamps,
-            }
-        else:
-            sequences = padded_outputs
-            outputs = {
-                "sequences": sequences,
-            }
+                if num_return_sequences > 1:
+                    if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
+                        dict_outputs.encoder_attentions = tuple(
+                            dict_outputs.encoder_attentions[i][::num_return_sequences]
+                            for i in range(len(dict_outputs.encoder_attentions))
+                        )
+                    if (
+                        hasattr(dict_outputs, "encoder_hidden_states")
+                        and dict_outputs.encoder_hidden_states is not None
+                    ):
+                        dict_outputs.encoder_hidden_states = tuple(
+                            dict_outputs.encoder_hidden_states[i][::num_return_sequences]
+                            for i in range(len(dict_outputs.encoder_hidden_states))
+                        )
+                if return_token_timestamps:
+                    dict_outputs["token_timestamps"] = outputs["token_timestamps"]
+                return dict_outputs
 
-        if return_segments:
-            outputs["segments"] = final_segments
+            return outputs
 
-        return outputs
+        return sequences
 
     @add_start_docstrings(
         "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
@@ -2270,132 +2174,6 @@ class CustomWhisperForConditionalGeneration(CustomWhisperGenerationMixin, Whispe
 
         # Initialize weights and apply final processing
         self.post_init()
-
-    def _pad_to_max_length(
-        current_segments,
-        pad_token_id,
-        device,
-        padding_side="right",
-        padding="longest",
-        bos_token_tensor=None,
-        cut_off_length=None,
-        return_token_timestamps=False,
-        force_unique_generate_call=False,
-    ):
-        max_total_length = 0
-        sequences = []
-        token_timestamps_list = []
-
-        if padding_side not in ["right", "left"]:
-            raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
-
-        if padding not in ["longest", "max_length"]:
-            raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
-        elif padding == "max_length" and cut_off_length is None:
-            raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
-
-        if force_unique_generate_call:
-            sequences_list = []
-            timestamps_list = []
-            for segments in current_segments:
-                result = segments[0]["result"]
-                sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
-                if return_token_timestamps:
-                    timestamps_list.append(result["token_timestamps"])
-
-            sequences = torch.stack(sequences_list, dim=0)
-            if return_token_timestamps:
-                token_timestamps = torch.stack(timestamps_list, dim=0)
-                return sequences, token_timestamps
-            return sequences
-
-        for current_segment_list in current_segments:
-            if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-                sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
-                if return_token_timestamps:
-                    token_timestamps = torch.cat(
-                        [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
-                        dim=-1,
-                    )
-
-                if cut_off_length is not None:
-                    sequence = sequence[-cut_off_length:]
-                    if return_token_timestamps:
-                        token_timestamps = token_timestamps[-cut_off_length:]
-
-                if bos_token_tensor is not None:
-                    sequence = torch.cat([bos_token_tensor, sequence])
-                    if return_token_timestamps:
-                        token_timestamps = torch.cat(
-                            [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
-                        )
-                sequences.append(sequence)
-                if return_token_timestamps:
-                    token_timestamps_list.append(token_timestamps)
-                max_total_length = max(max_total_length, len(sequences[-1]))
-            elif bos_token_tensor is not None:
-                sequences.append(bos_token_tensor)
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
-            else:
-                sequences.append(torch.tensor([], device=device))
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.tensor([], device=device))
-
-        max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
-        for i in range(len(current_segments)):
-            pad_length = max_total_length - len(sequences[i])
-            pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
-
-            sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
-            if return_token_timestamps:
-                token_timestamps_list[i] = F.pad(
-                    token_timestamps_list[i],
-                    pad=pad,
-                    value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
-                )
-
-        sequences = torch.stack(sequences, dim=0)
-
-        if return_token_timestamps:
-            token_timestamps = torch.stack(token_timestamps_list, dim=0)
-            return sequences, token_timestamps
-        else:
-            return sequences
-
-    padded_outputs = _pad_to_max_length(
-        current_segments=final_segments,
-        pad_token_id=generation_config.pad_token_id,
-        device=self.device,
-        padding_side="right",
-        return_token_timestamps=return_token_timestamps,
-        force_unique_generate_call=force_unique_generate_call,
-    )
-
-    if return_dict_in_generate and generation_config.return_dict_in_generate:
-        logger.warning_once(
-            "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
-        )
-        return_segments = True
-    elif not return_segments and not return_token_timestamps:
-        return padded_outputs
-
-    if return_token_timestamps:
-        sequences, token_timestamps = padded_outputs
-        outputs = {
-            "sequences": sequences,
-            "token_timestamps": token_timestamps,
-        }
-    else:
-        sequences = padded_outputs
-        outputs = {
-            "sequences": sequences,
-        }
-
-    if return_segments:
-        outputs["segments"] = final_segments
-
-    return outputs
 
     def get_encoder(self):
         return self.model.get_encoder()
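
For context, a minimal usage sketch of the patched generation path. This is an assumption-based illustration, not part of the commit: the checkpoint name, the audio array, and importing the class via "from model import ..." are placeholders chosen for the example, and any compatible Whisper checkpoint/processor is assumed.

# Hypothetical usage sketch; checkpoint and audio are illustrative placeholders only.
import numpy as np
from transformers import WhisperProcessor

from model import CustomWhisperForConditionalGeneration  # class whose generate() is patched in this commit

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")  # assumed-compatible processor
model = CustomWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 16 kHz mono audio as a float array; a short silent clip stands in for real speech here.
audio = np.zeros(16000 * 5, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# The patched generate follows the chunked loop shown in the diff: per-segment tokens are
# padded with _pad_to_max_length, and return_segments=True also returns the raw segments.
out = model.generate(inputs.input_features, return_segments=True, return_timestamps=True)
print(out["sequences"].shape)
print(processor.batch_decode(out["sequences"], skip_special_tokens=True))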