christopher-hoernle committed
Commit 51ea6ce · 1 Parent(s): 793fd0d
Files changed (1)
  1. model.py +459 -59
model.py CHANGED
@@ -1,6 +1,11 @@
  from transformers.models.whisper.configuration_whisper import WhisperConfig
- from typing import List, Literal, Optional, Dict, Any
- import types
+ import torch.nn.functional as F
+ from transformers.generation.logits_process import (
+     LogitsProcessorList,
+     SuppressTokensLogitsProcessor
+ )
+ from typing import List, Optional, Dict, Any
+ import warnings

  """Custom config to support modification of the Whisper encoder."""
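For orientation, a minimal sketch (not part of the commit) of how the newly imported LogitsProcessorList and SuppressTokensLogitsProcessor fit together; it mirrors the suppress-token lookup that the new generate() performs further down in this diff. The token id 50256 is only an illustrative value.

# Illustrative only; mirrors the `_get_attr_from_logit_processors` helper added below.
from transformers.generation.logits_process import (
    LogitsProcessorList,
    SuppressTokensLogitsProcessor,
)

processors = LogitsProcessorList([SuppressTokensLogitsProcessor([50256])])  # 50256 is an arbitrary example id
suppress_tokens = next(
    (getattr(p, "suppress_tokens", None) for p in processors if isinstance(p, SuppressTokensLogitsProcessor)),
    None,
)
print(suppress_tokens)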
 
@@ -1152,8 +1157,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
          )

          inputs_embeds = inputs_embeds.permute(0, 2, 1)
-         embed_pos = self.embed_positions.weight
-
+
          sequence_length = hidden_states.shape[1]

          # CUSTOM
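The removed line loaded the full, fixed-size positional-embedding table. Below is a standalone sketch of the likely replacement pattern; this is an assumption, since the body of the "# CUSTOM" block is not shown in this hunk. The idea is to slice the table to whatever sequence length the convolutional stack actually produced.

# Assumption, for illustration only: slice a fixed position table to a dynamic length.
import torch

max_source_positions, d_model = 1500, 8
embed_positions = torch.randn(max_source_positions, d_model)  # stands in for self.embed_positions.weight

hidden_states = torch.randn(2, 750, d_model)  # a deeper conv stack can yield fewer than 1500 positions
sequence_length = hidden_states.shape[1]
hidden_states = hidden_states + embed_positions[:sequence_length]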
@@ -1811,69 +1815,465 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTr
          self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
          self.max_target_positions = config.max_target_positions

-         self.patch_generate()
-
          # Initialize weights and apply final processing
          self.post_init()

-     # CUSTOM (Monkeypatch the generation method)
-     def patch_generate(self):
-         """
-         Monkey patches the WhisperGenerationMixin to use dynamic stride calculation
-         """
-         original_generate = WhisperGenerationMixin.generate
-
-         def get_conv_stride(self):
-             """Calculate total stride of all conv layers"""
-             total_stride = 1
-             for layer in self.model.encoder.conv_layers:
-                 total_stride *= layer.stride[0]
-             return total_stride

-         def generate_wrapper(self, *args, **kwargs):
-             # Store the original function logic
-             original_code = original_generate.__code__
-
-             # Create a modified version of the function that uses our stride calculation
-             modified_code = types.CodeType(
-                 original_code.co_argcount,
-                 original_code.co_posonlyargcount,
-                 original_code.co_kwonlyargcount,
-                 original_code.co_nlocals,
-                 original_code.co_stacksize,
-                 original_code.co_flags,
-                 original_code.co_code.replace(
-                     # Replace the hardcoded stride calculation with our dynamic one
-                     b"self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]",
-                     b"self.get_conv_stride()",
-                 ),
-                 original_code.co_consts,
-                 original_code.co_names,
-                 original_code.co_varnames,
-                 original_code.co_filename,
-                 original_code.co_name,
-                 original_code.co_firstlineno,
-                 original_code.co_lnotab,
-                 original_code.co_freevars,
-                 original_code.co_cellvars,
              )
-
-             # Create a new function with the modified code
-             new_generate = types.FunctionType(
-                 modified_code,
-                 original_generate.__globals__,
-                 original_generate.__name__,
-                 original_generate.__defaults__,
-                 original_generate.__closure__,
              )

-             # Bind the function to the instance and call it
-             return new_generate(self, *args, **kwargs)

-         # Add the stride calculation method to the mixin
-         WhisperGenerationMixin.get_conv_stride = get_conv_stride
-         # Replace the original generate method
-         WhisperGenerationMixin.generate = generate_wrapper
 
 
 
+     # CUSTOM (patch the generation method)
+     def get_conv_stride(self):
+         """Calculate total stride of all conv layers"""
+         total_stride = 1
+         for layer in self.model.encoder.conv_layers:
+             total_stride *= layer.stride[0]
+         return total_stride

+     def generate(
+         self,
+         input_features: Optional[torch.Tensor] = None,
+         generation_config: Optional[Any] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[Any] = None,
+         prefix_allowed_tokens_fn: Optional[Any] = None,
+         synced_gpus: bool = False,
+         return_timestamps: Optional[bool] = None,
+         task: Optional[str] = None,
+         language: Optional[Union[str, List[str]]] = None,
+         is_multilingual: Optional[bool] = None,
+         prompt_ids: Optional[torch.Tensor] = None,
+         prompt_condition_type: Optional[str] = None,  # first-segment, all-segments
+         condition_on_prev_tokens: Optional[bool] = None,
+         temperature: Optional[Union[float, Tuple[float, ...]]] = None,
+         compression_ratio_threshold: Optional[float] = None,
+         logprob_threshold: Optional[float] = None,
+         no_speech_threshold: Optional[float] = None,
+         num_segment_frames: Optional[int] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         time_precision: float = 0.02,
+         time_precision_features: float = 0.01,
+         return_token_timestamps: Optional[bool] = None,
+         return_segments: bool = False,
+         return_dict_in_generate: Optional[bool] = None,
+         force_unique_generate_call: Optional[bool] = None,
+         **kwargs,
+     ):
+         # 0. deprecate old inputs
+         if "inputs" in kwargs:
+             input_features = kwargs.pop("inputs")
+             warnings.warn(
+                 "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
+                 FutureWarning,
              )
+
+         # 1. prepare generation config
+         generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
+
+         # 2. set global generate variables
+         input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
+         num_segment_frames = input_stride * self.config.max_source_positions
+         batch_size, total_input_frames = self._retrieve_total_input_frames(
+             input_features=input_features, input_stride=input_stride, kwargs=kwargs
+         )
+         is_shortform = total_input_frames <= num_segment_frames
+
+         # 3. Make sure generation config is correctly set
+         # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
+         return_dict_in_generate = self._set_return_outputs(
+             return_dict_in_generate=return_dict_in_generate,
+             return_token_timestamps=return_token_timestamps,
+             logprob_threshold=logprob_threshold,
+             generation_config=generation_config,
+         )
+         timestamp_begin = self._set_return_timestamps(
+             return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
+         )
+         self._set_language_and_task(
+             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
+         )
+         self._set_num_frames(
+             return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
+         )
+         self._set_thresholds_and_condition(
+             generation_config=generation_config,
+             logprob_threshold=logprob_threshold,
+             compression_ratio_threshold=compression_ratio_threshold,
+             no_speech_threshold=no_speech_threshold,
+             condition_on_prev_tokens=condition_on_prev_tokens,
+         )
+         self._set_prompt_condition_type(
+             generation_config=generation_config,
+             prompt_condition_type=prompt_condition_type,
+         )
+
+         # pass self.config for backward compatibility
+         init_tokens = self._retrieve_init_tokens(
+             input_features,
+             batch_size=batch_size,
+             generation_config=generation_config,
+             config=self.config,
+             num_segment_frames=num_segment_frames,
+             kwargs=kwargs,
+         )
+         # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
+         # where the input ids are handled explicitly by the generate method
+         self._check_decoder_input_ids(kwargs=kwargs)
+
+         # 3. Retrieve logits processors
+         device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
+         begin_index = init_tokens.shape[1]
+         num_beams = kwargs.get(
+             "num_beams",
+             generation_config.num_beams
+             if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
+             else 1,
+         )
+         if "assistant_model" in kwargs:
+             # speculative decoding: the model should be able to return eos token
+             generation_config.begin_suppress_tokens = None
+
+         logits_processor = self._retrieve_logit_processors(
+             generation_config=generation_config,
+             logits_processor=logits_processor,
+             begin_index=begin_index,  # begin index is index of first generated decoder token
+             num_beams=num_beams,
+             device=device,
+         )
+
+         # 4 Set and retrieve global generation variables
+         self._set_condition_on_prev_tokens(
+             condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
+         )
+
+         temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
+         temperature = temperatures[0]
+
+         max_frames, seek = self._retrieve_max_frames_and_seek(
+             batch_size=batch_size,
+             attention_mask=attention_mask,
+             total_input_frames=total_input_frames,
+             is_shortform=is_shortform,
+         )
+
+         # 5 Prepare running variables, list for generation
+         num_return_sequences = generation_config.num_return_sequences
+         (
+             batch_idx_map,
+             cur_bsz,
+             input_features,
+             seek,
+             max_frames,
+             init_tokens,
+             do_condition_on_prev_tokens,
+         ) = self._expand_variables_for_generation(
+             input_features=input_features,
+             seek=seek,
+             max_frames=max_frames,
+             init_tokens=init_tokens,
+             batch_size=batch_size,
+             condition_on_prev_tokens=condition_on_prev_tokens,
+             generation_config=generation_config,
+         )
+
+         current_segments = self._prepare_segments(
+             prompt_ids=prompt_ids,
+             batch_size=cur_bsz,
+             generation_config=generation_config,
+         )
+         # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
+         # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
+         if "assistant_model" in kwargs:
+             assistant_model = kwargs["assistant_model"]
+             assistant_model.generation_config.force_unique_generate_call = True
+
+         if force_unique_generate_call is None:
+             if hasattr(generation_config, "force_unique_generate_call"):
+                 force_unique_generate_call = generation_config.force_unique_generate_call
+             elif hasattr(self.generation_config, "force_unique_generate_call"):
+                 force_unique_generate_call = self.generation_config.force_unique_generate_call
+             else:
+                 force_unique_generate_call = False
+
+         # 6 Transcribe audio until we reach the end of all input audios
+         while (seek < max_frames).any():
+             # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
+             # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
+             # to know which original audio is being decoded
+             # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
+             input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
+                 input_features=input_features,
+                 seek=seek,
+                 max_frames=max_frames,
+                 cur_bsz=cur_bsz,
+                 batch_idx_map=batch_idx_map,
+             )
+             time_offset = (
+                 seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
+             )
+             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
+
+             # 6.2 cut out next 30s segment from input features
+             segment_input = self._get_input_segment(
+                 input_features=input_features,
+                 seek=seek,
+                 seek_num_frames=seek_num_frames,
+                 num_segment_frames=num_segment_frames,
+                 cur_bsz=cur_bsz,
+                 batch_idx_map=batch_idx_map,
              )

+             def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
+                 if logits_processor is not None:
+                     logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
+                     if logit_processor:
+                         return getattr(logit_processor, attribute_name, None)
+                 return None
+
+             # 6.3 prepare decoder input ids
+             suppress_tokens = _get_attr_from_logit_processors(
+                 logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
+             )
+
+             decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
+                 cur_bsz=cur_bsz,
+                 init_tokens=init_tokens,
+                 current_segments=current_segments,
+                 batch_idx_map=batch_idx_map,
+                 do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                 prompt_ids=prompt_ids,
+                 generation_config=generation_config,
+                 config=self.config,
+                 device=init_tokens.device,
+                 suppress_tokens=suppress_tokens,
+                 timestamp_begin=timestamp_begin,
+                 kwargs=kwargs,
+             )
+
+             # 6.4 set max new tokens or max length
+             self._set_max_new_tokens_and_length(
+                 config=self.config,
+                 decoder_input_ids=decoder_input_ids,
+                 generation_config=generation_config,
+             )
+
+             # 6.5 Set current `begin_index` for all logit processors
+             if logits_processor is not None:
+                 for proc in logits_processor:
+                     if hasattr(proc, "set_begin_index"):
+                         proc.set_begin_index(decoder_input_ids.shape[-1])
+
+             # 6.6 Run generate with fallback
+             (
+                 seek_sequences,
+                 seek_outputs,
+                 should_skip,
+                 do_condition_on_prev_tokens,
+                 model_output_type,
+             ) = self.generate_with_fallback(
+                 segment_input=segment_input,
+                 decoder_input_ids=decoder_input_ids,
+                 cur_bsz=cur_bsz,
+                 batch_idx_map=batch_idx_map,
+                 seek=seek,
+                 num_segment_frames=num_segment_frames,
+                 max_frames=max_frames,
+                 temperatures=temperatures,
+                 generation_config=generation_config,
+                 logits_processor=logits_processor,
+                 stopping_criteria=stopping_criteria,
+                 prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                 synced_gpus=synced_gpus,
+                 return_token_timestamps=return_token_timestamps,
+                 do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                 is_shortform=is_shortform,
+                 batch_size=batch_size,
+                 attention_mask=attention_mask,
+                 kwargs=kwargs,
+             )
+
+             # 6.7 In every generated sequence, split by timestamp tokens and extract segments
+             for i, seek_sequence in enumerate(seek_sequences):
+                 prev_i = batch_idx_map[i]
+
+                 if should_skip[i]:
+                     seek[prev_i] += seek_num_frames[prev_i]
+                     continue
+
+                 segments, segment_offset = self._retrieve_segment(
+                     seek_sequence=seek_sequence,
+                     seek_outputs=seek_outputs,
+                     time_offset=time_offset,
+                     timestamp_begin=timestamp_begin,
+                     seek_num_frames=seek_num_frames,
+                     time_precision=time_precision,
+                     time_precision_features=time_precision_features,
+                     input_stride=input_stride,
+                     prev_idx=prev_i,
+                     idx=i,
+                     return_token_timestamps=return_token_timestamps,
+                     decoder_input_ids=decoder_input_ids,
+                 )
+
+                 seek[prev_i] += segment_offset
+
+                 current_segments[prev_i] += segments
+
+             if force_unique_generate_call:
+                 break
+
+         # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
+         # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
+         final_segments = (
+             [x[1:] for x in current_segments]
+             if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
+             else current_segments
+         )
+
+         # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
+         # -> we can return a ModelOutput
+         # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
+         if (
+             return_dict_in_generate
+             and generation_config.return_dict_in_generate
+             and (force_unique_generate_call or not return_timestamps)
+         ):
+             # only one call to generate_with_fallback, we can return a ModelOutput
+             outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
+             if num_return_sequences > 1:
+                 if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
+                     outputs.encoder_attentions = tuple(
+                         outputs.encoder_attentions[i][::num_return_sequences]
+                         for i in range(len(outputs.encoder_attentions))
+                     )
+                 if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
+                     outputs.encoder_hidden_states = tuple(
+                         outputs.encoder_hidden_states[i][::num_return_sequences]
+                         for i in range(len(outputs.encoder_hidden_states))
+                     )
+             return outputs

+         def _pad_to_max_length(
+             current_segments,
+             pad_token_id,
+             device,
+             padding_side="right",
+             padding="longest",
+             bos_token_tensor=None,
+             cut_off_length=None,
+             return_token_timestamps=False,
+             force_unique_generate_call=False,
+         ):
+             max_total_length = 0
+             sequences = []
+             token_timestamps_list = []
+
+             if padding_side not in ["right", "left"]:
+                 raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+             if padding not in ["longest", "max_length"]:
+                 raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+             elif padding == "max_length" and cut_off_length is None:
+                 raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
+
+             if force_unique_generate_call:
+                 sequences_list = []
+                 timestamps_list = []
+                 for segments in current_segments:
+                     result = segments[0]["result"]
+                     sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
+                     if return_token_timestamps:
+                         timestamps_list.append(result["token_timestamps"])
+
+                 sequences = torch.stack(sequences_list, dim=0)
+                 if return_token_timestamps:
+                     token_timestamps = torch.stack(timestamps_list, dim=0)
+                     return sequences, token_timestamps
+                 return sequences
+
+             for current_segment_list in current_segments:
+                 if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
+                     sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+                     if return_token_timestamps:
+                         token_timestamps = torch.cat(
+                             [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
+                             dim=-1,
+                         )
+
+                     if cut_off_length is not None:
+                         sequence = sequence[-cut_off_length:]
+                         if return_token_timestamps:
+                             token_timestamps = token_timestamps[-cut_off_length:]
+
+                     if bos_token_tensor is not None:
+                         sequence = torch.cat([bos_token_tensor, sequence])
+                         if return_token_timestamps:
+                             token_timestamps = torch.cat(
+                                 [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
+                             )
+                     sequences.append(sequence)
+                     if return_token_timestamps:
+                         token_timestamps_list.append(token_timestamps)
+                     max_total_length = max(max_total_length, len(sequences[-1]))
+                 elif bos_token_tensor is not None:
+                     sequences.append(bos_token_tensor)
+                     if return_token_timestamps:
+                         token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
+                 else:
+                     sequences.append(torch.tensor([], device=device))
+                     if return_token_timestamps:
+                         token_timestamps_list.append(torch.tensor([], device=device))
+
+             max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
+             for i in range(len(current_segments)):
+                 pad_length = max_total_length - len(sequences[i])
+                 pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
+                 sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+                 if return_token_timestamps:
+                     token_timestamps_list[i] = F.pad(
+                         token_timestamps_list[i],
+                         pad=pad,
+                         value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
+                     )
+
+             sequences = torch.stack(sequences, dim=0)
+
+             if return_token_timestamps:
+                 token_timestamps = torch.stack(token_timestamps_list, dim=0)
+                 return sequences, token_timestamps
+             else:
+                 return sequences
+
+         padded_outputs = _pad_to_max_length(
+             current_segments=final_segments,
+             pad_token_id=generation_config.pad_token_id,
+             device=self.device,
+             padding_side="right",
+             return_token_timestamps=return_token_timestamps,
+             force_unique_generate_call=force_unique_generate_call,
+         )
+
+         if return_dict_in_generate and generation_config.return_dict_in_generate:
+             logger.warning_once(
+                 "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
+             )
+             return_segments = True
+         elif not return_segments and not return_token_timestamps:
+             return padded_outputs
+
+         if return_token_timestamps:
+             sequences, token_timestamps = padded_outputs
+             outputs = {
+                 "sequences": sequences,
+                 "token_timestamps": token_timestamps,
+             }
+         else:
+             sequences = padded_outputs
+             outputs = {
+                 "sequences": sequences,
+             }
+
+         if return_segments:
+             outputs["segments"] = final_segments
+
+         return outputs

      def get_encoder(self):
          return self.model.get_encoder()
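Taken together, the commit replaces the removed bytecode monkey patch with a direct override: get_conv_stride() multiplies the strides of every layer in the custom conv_layers stack, and the overridden generate() reproduces the upstream long-form transcription loop. Below is a hedged usage sketch; the checkpoint path, base processor, and dummy audio are assumptions for illustration, not taken from this repo.

# Usage sketch (illustrative only; paths and checkpoints are placeholders).
import numpy as np
from transformers import WhisperProcessor
from model import CustomWhisperForConditionalGeneration  # this repo's model.py

processor = WhisperProcessor.from_pretrained("openai/whisper-small")                 # assumed base checkpoint
model = CustomWhisperForConditionalGeneration.from_pretrained("path/to/checkpoint")  # placeholder path

# Dynamic total stride over all encoder conv layers; equals the hard-coded
# conv1.stride[0] * conv2.stride[0] (= 2) for the stock two-layer encoder.
input_stride = model.get_conv_stride()
num_segment_frames = input_stride * model.config.max_source_positions  # e.g. 2 * 1500 = 3000 mel frames, about 30 s

audio = np.zeros(16_000 * 45, dtype=np.float32)  # 45 s dummy waveform to exercise the long-form loop
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt",
                   truncation=False, padding="longest")  # keep features longer than one 30 s segment

out = model.generate(
    inputs.input_features,
    return_timestamps=True,   # long-form mode: sequences are split on timestamp tokens
    return_segments=True,     # also return the per-segment dicts built in step 6.7
)
print(out["sequences"].shape, len(out["segments"][0]))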