christopher-hoernle committed · Commit 51ea6ce · 1 Parent(s): 793fd0d

fixes

model.py CHANGED
@@ -1,6 +1,11 @@
 from transformers.models.whisper.configuration_whisper import WhisperConfig
-
-import
+import torch.nn.functional as F
+from transformers.generation.logits_process import (
+    LogitsProcessorList,
+    SuppressTokensLogitsProcessor
+)
+from typing import List, Optional, Dict, Any
+import warnings

 """Custom config to support modification of the Whisper encoder."""

@@ -1152,8 +1157,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
         )

         inputs_embeds = inputs_embeds.permute(0, 2, 1)
-
-
+
         sequence_length = hidden_states.shape[1]

         # CUSTOM
@@ -1811,69 +1815,465 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
         self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
         self.max_target_positions = config.max_target_positions

-        self.patch_generate()
-
         # Initialize weights and apply final processing
         self.post_init()

-    # CUSTOM (
-    def
-        """
-
-
-
-
-    def get_conv_stride(self):
-        """Calculate total stride of all conv layers"""
-        total_stride = 1
-        for layer in self.model.encoder.conv_layers:
-            total_stride *= layer.stride[0]
-        return total_stride
+    # CUSTOM (patch the generation method)
+    def get_conv_stride(self):
+        """Calculate total stride of all conv layers"""
+        total_stride = 1
+        for layer in self.model.encoder.conv_layers:
+            total_stride *= layer.stride[0]
+        return total_stride

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def generate(
+        self,
+        input_features: Optional[torch.Tensor] = None,
+        generation_config: Optional[Any] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[Any] = None,
+        prefix_allowed_tokens_fn: Optional[Any] = None,
+        synced_gpus: bool = False,
+        return_timestamps: Optional[bool] = None,
+        task: Optional[str] = None,
+        language: Optional[Union[str, List[str]]] = None,
+        is_multilingual: Optional[bool] = None,
+        prompt_ids: Optional[torch.Tensor] = None,
+        prompt_condition_type: Optional[str] = None,  # first-segment, all-segments
+        condition_on_prev_tokens: Optional[bool] = None,
+        temperature: Optional[Union[float, Tuple[float, ...]]] = None,
+        compression_ratio_threshold: Optional[float] = None,
+        logprob_threshold: Optional[float] = None,
+        no_speech_threshold: Optional[float] = None,
+        num_segment_frames: Optional[int] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        time_precision: float = 0.02,
+        time_precision_features: float = 0.01,
+        return_token_timestamps: Optional[bool] = None,
+        return_segments: bool = False,
+        return_dict_in_generate: Optional[bool] = None,
+        force_unique_generate_call: Optional[bool] = None,
+        **kwargs,
+    ):
+        # 0. deprecate old inputs
+        if "inputs" in kwargs:
+            input_features = kwargs.pop("inputs")
+            warnings.warn(
+                "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
+                FutureWarning,
             )
-
-
-
-
-
-
-
-
+
+        # 1. prepare generation config
+        generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
+
+        # 2. set global generate variables
+        input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
+        num_segment_frames = input_stride * self.config.max_source_positions
+        batch_size, total_input_frames = self._retrieve_total_input_frames(
+            input_features=input_features, input_stride=input_stride, kwargs=kwargs
+        )
+        is_shortform = total_input_frames <= num_segment_frames
+
+        # 3. Make sure generation config is correctly set
+        # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
+        return_dict_in_generate = self._set_return_outputs(
+            return_dict_in_generate=return_dict_in_generate,
+            return_token_timestamps=return_token_timestamps,
+            logprob_threshold=logprob_threshold,
+            generation_config=generation_config,
+        )
+        timestamp_begin = self._set_return_timestamps(
+            return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
+        )
+        self._set_language_and_task(
+            language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
+        )
+        self._set_num_frames(
+            return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
+        )
+        self._set_thresholds_and_condition(
+            generation_config=generation_config,
+            logprob_threshold=logprob_threshold,
+            compression_ratio_threshold=compression_ratio_threshold,
+            no_speech_threshold=no_speech_threshold,
+            condition_on_prev_tokens=condition_on_prev_tokens,
+        )
+        self._set_prompt_condition_type(
+            generation_config=generation_config,
+            prompt_condition_type=prompt_condition_type,
+        )
+
+        # pass self.config for backward compatibility
+        init_tokens = self._retrieve_init_tokens(
+            input_features,
+            batch_size=batch_size,
+            generation_config=generation_config,
+            config=self.config,
+            num_segment_frames=num_segment_frames,
+            kwargs=kwargs,
+        )
+        # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
+        # where the input ids are handled explicitly by the generate method
+        self._check_decoder_input_ids(kwargs=kwargs)
+
+        # 3. Retrieve logits processors
+        device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
+        begin_index = init_tokens.shape[1]
+        num_beams = kwargs.get(
+            "num_beams",
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
+            else 1,
+        )
+        if "assistant_model" in kwargs:
+            # speculative decoding: the model should be able to return eos token
+            generation_config.begin_suppress_tokens = None
+
+        logits_processor = self._retrieve_logit_processors(
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            begin_index=begin_index,  # begin index is index of first generated decoder token
+            num_beams=num_beams,
+            device=device,
+        )
+
+        # 4 Set and retrieve global generation variables
+        self._set_condition_on_prev_tokens(
+            condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
+        )
+
+        temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
+        temperature = temperatures[0]
+
+        max_frames, seek = self._retrieve_max_frames_and_seek(
+            batch_size=batch_size,
+            attention_mask=attention_mask,
+            total_input_frames=total_input_frames,
+            is_shortform=is_shortform,
+        )
+
+        # 5 Prepare running variables, list for generation
+        num_return_sequences = generation_config.num_return_sequences
+        (
+            batch_idx_map,
+            cur_bsz,
+            input_features,
+            seek,
+            max_frames,
+            init_tokens,
+            do_condition_on_prev_tokens,
+        ) = self._expand_variables_for_generation(
+            input_features=input_features,
+            seek=seek,
+            max_frames=max_frames,
+            init_tokens=init_tokens,
+            batch_size=batch_size,
+            condition_on_prev_tokens=condition_on_prev_tokens,
+            generation_config=generation_config,
+        )
+
+        current_segments = self._prepare_segments(
+            prompt_ids=prompt_ids,
+            batch_size=cur_bsz,
+            generation_config=generation_config,
+        )
+        # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
+        # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
+        if "assistant_model" in kwargs:
+            assistant_model = kwargs["assistant_model"]
+            assistant_model.generation_config.force_unique_generate_call = True
+
+        if force_unique_generate_call is None:
+            if hasattr(generation_config, "force_unique_generate_call"):
+                force_unique_generate_call = generation_config.force_unique_generate_call
+            elif hasattr(self.generation_config, "force_unique_generate_call"):
+                force_unique_generate_call = self.generation_config.force_unique_generate_call
+            else:
+                force_unique_generate_call = False
+
+        # 6 Transcribe audio until we reach the end of all input audios
+        while (seek < max_frames).any():
+            # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
+            # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
+            # to know which original audio is being decoded
+            # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
+            input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
+                input_features=input_features,
+                seek=seek,
+                max_frames=max_frames,
+                cur_bsz=cur_bsz,
+                batch_idx_map=batch_idx_map,
+            )
+            time_offset = (
+                seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
+            )
+            seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
+
+            # 6.2 cut out next 30s segment from input features
+            segment_input = self._get_input_segment(
+                input_features=input_features,
+                seek=seek,
+                seek_num_frames=seek_num_frames,
+                num_segment_frames=num_segment_frames,
+                cur_bsz=cur_bsz,
+                batch_idx_map=batch_idx_map,
             )

-
-
+            def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
+                if logits_processor is not None:
+                    logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
+                    if logit_processor:
+                        return getattr(logit_processor, attribute_name, None)
+                return None
+
+            # 6.3 prepare decoder input ids
+            suppress_tokens = _get_attr_from_logit_processors(
+                logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
+            )
+
+            decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
+                cur_bsz=cur_bsz,
+                init_tokens=init_tokens,
+                current_segments=current_segments,
+                batch_idx_map=batch_idx_map,
+                do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                prompt_ids=prompt_ids,
+                generation_config=generation_config,
+                config=self.config,
+                device=init_tokens.device,
+                suppress_tokens=suppress_tokens,
+                timestamp_begin=timestamp_begin,
+                kwargs=kwargs,
+            )
+
+            # 6.4 set max new tokens or max length
+            self._set_max_new_tokens_and_length(
+                config=self.config,
+                decoder_input_ids=decoder_input_ids,
+                generation_config=generation_config,
+            )
+
+            # 6.5 Set current `begin_index` for all logit processors
+            if logits_processor is not None:
+                for proc in logits_processor:
+                    if hasattr(proc, "set_begin_index"):
+                        proc.set_begin_index(decoder_input_ids.shape[-1])
+
+            # 6.6 Run generate with fallback
+            (
+                seek_sequences,
+                seek_outputs,
+                should_skip,
+                do_condition_on_prev_tokens,
+                model_output_type,
+            ) = self.generate_with_fallback(
+                segment_input=segment_input,
+                decoder_input_ids=decoder_input_ids,
+                cur_bsz=cur_bsz,
+                batch_idx_map=batch_idx_map,
+                seek=seek,
+                num_segment_frames=num_segment_frames,
+                max_frames=max_frames,
+                temperatures=temperatures,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                synced_gpus=synced_gpus,
+                return_token_timestamps=return_token_timestamps,
+                do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                is_shortform=is_shortform,
+                batch_size=batch_size,
+                attention_mask=attention_mask,
+                kwargs=kwargs,
+            )
+
+            # 6.7 In every generated sequence, split by timestamp tokens and extract segments
+            for i, seek_sequence in enumerate(seek_sequences):
+                prev_i = batch_idx_map[i]
+
+                if should_skip[i]:
+                    seek[prev_i] += seek_num_frames[prev_i]
+                    continue
+
+                segments, segment_offset = self._retrieve_segment(
+                    seek_sequence=seek_sequence,
+                    seek_outputs=seek_outputs,
+                    time_offset=time_offset,
+                    timestamp_begin=timestamp_begin,
+                    seek_num_frames=seek_num_frames,
+                    time_precision=time_precision,
+                    time_precision_features=time_precision_features,
+                    input_stride=input_stride,
+                    prev_idx=prev_i,
+                    idx=i,
+                    return_token_timestamps=return_token_timestamps,
+                    decoder_input_ids=decoder_input_ids,
+                )
+
+                seek[prev_i] += segment_offset
+
+                current_segments[prev_i] += segments
+
+            if force_unique_generate_call:
+                break
+
+        # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
+        # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
+        final_segments = (
+            [x[1:] for x in current_segments]
+            if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
+            else current_segments
+        )
+
+        # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
+        # -> we can return a ModelOutput
+        # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
+        if (
+            return_dict_in_generate
+            and generation_config.return_dict_in_generate
+            and (force_unique_generate_call or not return_timestamps)
+        ):
+            # only one call to generate_with_fallback, we can return a ModelOutput
+            outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
+            if num_return_sequences > 1:
+                if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
+                    outputs.encoder_attentions = tuple(
+                        outputs.encoder_attentions[i][::num_return_sequences]
+                        for i in range(len(outputs.encoder_attentions))
+                    )
+                if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
+                    outputs.encoder_hidden_states = tuple(
+                        outputs.encoder_hidden_states[i][::num_return_sequences]
+                        for i in range(len(outputs.encoder_hidden_states))
+                    )
+            return outputs

-
-
-
-
+        def _pad_to_max_length(
+            current_segments,
+            pad_token_id,
+            device,
+            padding_side="right",
+            padding="longest",
+            bos_token_tensor=None,
+            cut_off_length=None,
+            return_token_timestamps=False,
+            force_unique_generate_call=False,
+        ):
+            max_total_length = 0
+            sequences = []
+            token_timestamps_list = []
+
+            if padding_side not in ["right", "left"]:
+                raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+            if padding not in ["longest", "max_length"]:
+                raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+            elif padding == "max_length" and cut_off_length is None:
+                raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
+
+            if force_unique_generate_call:
+                sequences_list = []
+                timestamps_list = []
+                for segments in current_segments:
+                    result = segments[0]["result"]
+                    sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
+                    if return_token_timestamps:
+                        timestamps_list.append(result["token_timestamps"])
+
+                sequences = torch.stack(sequences_list, dim=0)
+                if return_token_timestamps:
+                    token_timestamps = torch.stack(timestamps_list, dim=0)
+                    return sequences, token_timestamps
+                return sequences
+
+            for current_segment_list in current_segments:
+                if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
+                    sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+                    if return_token_timestamps:
+                        token_timestamps = torch.cat(
+                            [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
+                            dim=-1,
+                        )
+
+                    if cut_off_length is not None:
+                        sequence = sequence[-cut_off_length:]
+                        if return_token_timestamps:
+                            token_timestamps = token_timestamps[-cut_off_length:]
+
+                    if bos_token_tensor is not None:
+                        sequence = torch.cat([bos_token_tensor, sequence])
+                        if return_token_timestamps:
+                            token_timestamps = torch.cat(
+                                [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
+                            )
+                    sequences.append(sequence)
+                    if return_token_timestamps:
+                        token_timestamps_list.append(token_timestamps)
+                    max_total_length = max(max_total_length, len(sequences[-1]))
+                elif bos_token_tensor is not None:
+                    sequences.append(bos_token_tensor)
+                    if return_token_timestamps:
+                        token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
+                else:
+                    sequences.append(torch.tensor([], device=device))
+                    if return_token_timestamps:
+                        token_timestamps_list.append(torch.tensor([], device=device))
+
+            max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
+            for i in range(len(current_segments)):
+                pad_length = max_total_length - len(sequences[i])
+                pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
+                sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+                if return_token_timestamps:
+                    token_timestamps_list[i] = F.pad(
+                        token_timestamps_list[i],
+                        pad=pad,
+                        value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
+                    )
+
+            sequences = torch.stack(sequences, dim=0)
+
+            if return_token_timestamps:
+                token_timestamps = torch.stack(token_timestamps_list, dim=0)
+                return sequences, token_timestamps
+            else:
+                return sequences
+
+        padded_outputs = _pad_to_max_length(
+            current_segments=final_segments,
+            pad_token_id=generation_config.pad_token_id,
+            device=self.device,
+            padding_side="right",
+            return_token_timestamps=return_token_timestamps,
+            force_unique_generate_call=force_unique_generate_call,
+        )
+
+        if return_dict_in_generate and generation_config.return_dict_in_generate:
+            logger.warning_once(
+                "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
+            )
+            return_segments = True
+        elif not return_segments and not return_token_timestamps:
+            return padded_outputs
+
+        if return_token_timestamps:
+            sequences, token_timestamps = padded_outputs
+            outputs = {
+                "sequences": sequences,
+                "token_timestamps": token_timestamps,
+            }
+        else:
+            sequences = padded_outputs
+            outputs = {
+                "sequences": sequences,
+            }
+
+        if return_segments:
+            outputs["segments"] = final_segments
+
+        return outputs

     def get_encoder(self):
         return self.model.get_encoder()
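
The commit replaces the former patch_generate() hook with a generate override defined directly on CustomWhisperForConditionalGeneration (with _pad_to_max_length nested inside it), so the model can be driven through the standard Transformers generation API. A minimal usage sketch follows; it assumes this repository is set up for trust_remote_code loading, and the repository id and the silent dummy waveform are placeholders, not part of the commit.

# Minimal usage sketch (assumptions: placeholder repo id; the repository exposes
# CustomWhisperForConditionalGeneration via trust_remote_code / auto_map).
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor

repo_id = "your-namespace/custom-whisper"  # hypothetical repository id
processor = WhisperProcessor.from_pretrained(repo_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, trust_remote_code=True)

waveform = np.zeros(16000 * 30, dtype=np.float32)  # 30 s of silence as stand-in audio
input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features

# The patched generate returns a dict with "sequences" (and "segments" when requested).
out = model.generate(input_features, return_timestamps=True, return_segments=True)
print(processor.batch_decode(out["sequences"], skip_special_tokens=True))

With return_segments=True the per-chunk results of the long-form loop are kept alongside the padded token sequences, which is the path the added code exercises once the input exceeds one 30-second segment.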