update generate method of mixin
model.py
@@ -1801,17 +1801,61 @@ class WhisperModel(WhisperPreTrainedModel):
             encoder_attentions=encoder_outputs.attentions,
         )

-
+def _pad_to_max_length(
+    current_segments,
+    pad_token_id,
+    device,
+    padding_side="right",
+    padding="longest",
+    bos_token_tensor=None,
+    cut_off_length=None,
+):
+    max_total_length = 0
+    sequences = []
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
+
+    for current_segment_list in current_segments:
+        if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
+            sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+
+            if cut_off_length is not None:
+                sequence = sequence[-cut_off_length:]
+
+            if bos_token_tensor is not None:
+                sequence = torch.cat([bos_token_tensor, sequence])
+
+            sequences.append(sequence)
+            max_total_length = max(max_total_length, len(sequences[-1]))
+        elif bos_token_tensor is not None:
+            sequences.append(bos_token_tensor)
+        else:
+            sequences.append(torch.tensor([], device=device))
+
+    max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
+    for i in range(len(current_segments)):
+        pad_length = max_total_length - len(sequences[i])
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+        sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+
+    sequences = torch.stack(sequences, dim=0)
+    return sequences

 # CUSTOM (patch the generation method)
 class CustomWhisperGenerationMixin(WhisperGenerationMixin):
     def generate(
         self,
         input_features: Optional[torch.Tensor] = None,
-        generation_config: Optional[
+        generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[
-        prefix_allowed_tokens_fn: Optional[
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         synced_gpus: bool = False,
         return_timestamps: Optional[bool] = None,
         task: Optional[str] = None,
@@ -1827,11 +1871,9 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         num_segment_frames: Optional[int] = None,
         attention_mask: Optional[torch.Tensor] = None,
         time_precision: float = 0.02,
-        time_precision_features: float = 0.01,
         return_token_timestamps: Optional[bool] = None,
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
-        force_unique_generate_call: Optional[bool] = None,
         **kwargs,
     ):
         # 0. deprecate old inputs
@@ -1846,7 +1888,7 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)

         # 2. set global generate variables
-        input_stride = self.model.encoder.
+        input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         num_segment_frames = input_stride * self.config.max_source_positions
         batch_size, total_input_frames = self._retrieve_total_input_frames(
             input_features=input_features, input_stride=input_stride, kwargs=kwargs
@@ -1898,21 +1940,11 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
         # 3. Retrieve logits processors
         device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
         begin_index = init_tokens.shape[1]
-        num_beams = kwargs.get(
-            "num_beams",
-            generation_config.num_beams
-            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
-            else 1,
-        )
-        if "assistant_model" in kwargs:
-            # speculative decoding: the model should be able to return eos token
-            generation_config.begin_suppress_tokens = None
-
         logits_processor = self._retrieve_logit_processors(
             generation_config=generation_config,
             logits_processor=logits_processor,
             begin_index=begin_index, # begin index is index of first generated decoder token
-            num_beams=num_beams,
+            num_beams=kwargs.get("num_beams", 1),
             device=device,
         )

@@ -1956,19 +1988,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
             batch_size=cur_bsz,
             generation_config=generation_config,
         )
-        # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
-        # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
-        if "assistant_model" in kwargs:
-            assistant_model = kwargs["assistant_model"]
-            assistant_model.generation_config.force_unique_generate_call = True
-
-        if force_unique_generate_call is None:
-            if hasattr(generation_config, "force_unique_generate_call"):
-                force_unique_generate_call = generation_config.force_unique_generate_call
-            elif hasattr(self.generation_config, "force_unique_generate_call"):
-                force_unique_generate_call = self.generation_config.force_unique_generate_call
-            else:
-                force_unique_generate_call = False

         # 6 Transcribe audio until we reach the end of all input audios
         while (seek < max_frames).any():
@@ -1983,9 +2002,7 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = (
-                seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
-            )
+            time_offset = seek * time_precision / input_stride
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

             # 6.2 cut out next 30s segment from input features
@@ -1997,13 +2014,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-
-            def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
-                if logits_processor is not None:
-                    logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
-                    if logit_processor:
-                        return getattr(logit_processor, attribute_name, None)
-                return None

             # 6.3 prepare decoder input ids
             suppress_tokens = _get_attr_from_logit_processors(
@@ -2021,7 +2031,6 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                 config=self.config,
                 device=init_tokens.device,
                 suppress_tokens=suppress_tokens,
-                timestamp_begin=timestamp_begin,
                 kwargs=kwargs,
             )

@@ -2082,20 +2091,18 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
                     timestamp_begin=timestamp_begin,
                     seek_num_frames=seek_num_frames,
                     time_precision=time_precision,
-                    time_precision_features=time_precision_features,
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
                     return_token_timestamps=return_token_timestamps,
-                    decoder_input_ids=decoder_input_ids,
                 )

-                seek[prev_i] += segment_offset
-
                 current_segments[prev_i] += segments

-
-
+                if is_shortform:
+                    seek[prev_i] += max_frames[i]
+                else:
+                    seek[prev_i] += segment_offset

         # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
         # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -2105,154 +2112,51 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
             else current_segments
         )

-
-
-
-        if (
-            return_dict_in_generate
-            and generation_config.return_dict_in_generate
-            and (force_unique_generate_call or not return_timestamps)
-        ):
-            # only one call to generate_with_fallback, we can return a ModelOutput
-            outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
-            if num_return_sequences > 1:
-                if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
-                    outputs.encoder_attentions = tuple(
-                        outputs.encoder_attentions[i][::num_return_sequences]
-                        for i in range(len(outputs.encoder_attentions))
-                    )
-                if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
-                    outputs.encoder_hidden_states = tuple(
-                        outputs.encoder_hidden_states[i][::num_return_sequences]
-                        for i in range(len(outputs.encoder_hidden_states))
-                    )
-            return outputs
-
-    def _pad_to_max_length(
-        current_segments,
-        pad_token_id,
-        device,
-        padding_side="right",
-        padding="longest",
-        bos_token_tensor=None,
-        cut_off_length=None,
-        return_token_timestamps=False,
-        force_unique_generate_call=False,
-    ):
-        max_total_length = 0
-        sequences = []
-        token_timestamps_list = []
-
-        if padding_side not in ["right", "left"]:
-            raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
-
-        if padding not in ["longest", "max_length"]:
-            raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
-        elif padding == "max_length" and cut_off_length is None:
-            raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
-
-        if force_unique_generate_call:
-            sequences_list = []
-            timestamps_list = []
-            for segments in current_segments:
-                result = segments[0]["result"]
-                sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
-                if return_token_timestamps:
-                    timestamps_list.append(result["token_timestamps"])
-
-            sequences = torch.stack(sequences_list, dim=0)
-            if return_token_timestamps:
-                token_timestamps = torch.stack(timestamps_list, dim=0)
-                return sequences, token_timestamps
-            return sequences
-
-        for current_segment_list in current_segments:
-            if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-                sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
-                if return_token_timestamps:
-                    token_timestamps = torch.cat(
-                        [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
-                        dim=-1,
-                    )
-
-                if cut_off_length is not None:
-                    sequence = sequence[-cut_off_length:]
-                    if return_token_timestamps:
-                        token_timestamps = token_timestamps[-cut_off_length:]
-
-                if bos_token_tensor is not None:
-                    sequence = torch.cat([bos_token_tensor, sequence])
-                    if return_token_timestamps:
-                        token_timestamps = torch.cat(
-                            [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
-                        )
-                sequences.append(sequence)
-                if return_token_timestamps:
-                    token_timestamps_list.append(token_timestamps)
-                max_total_length = max(max_total_length, len(sequences[-1]))
-            elif bos_token_tensor is not None:
-                sequences.append(bos_token_tensor)
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
-            else:
-                sequences.append(torch.tensor([], device=device))
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.tensor([], device=device))
-
-        max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
-        for i in range(len(current_segments)):
-            pad_length = max_total_length - len(sequences[i])
-            pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+        sequences = _pad_to_max_length(
+            final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
+        )

-            sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
-            if return_token_timestamps:
-                token_timestamps_list[i] = F.pad(
-                    token_timestamps_list[i],
-                    pad=pad,
-                    value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
-                )
+        # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
+        if return_segments:
+            return {"sequences": sequences, "segments": final_segments}

-        sequences = torch.stack(sequences, dim=0)
+        if is_shortform:
+            # add eos token:
+            if generation_config.max_new_tokens is None and generation_config.max_length is None:
+                eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
+                sequences = torch.cat([sequences, eos_tokens], dim=-1)

         if return_token_timestamps:
-            token_timestamps = torch.stack(token_timestamps_list, dim=0)
-            return sequences, token_timestamps
+            outputs = {}
+            outputs["sequences"] = sequences
+            outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
         else:
-            return sequences
+            outputs = sequences

-        padded_outputs = _pad_to_max_length(
-            current_segments=final_segments,
-            pad_token_id=generation_config.pad_token_id,
-            device=self.device,
-            padding_side="right",
-            return_token_timestamps=return_token_timestamps,
-            force_unique_generate_call=force_unique_generate_call,
-        )
+        if return_dict_in_generate and generation_config.return_dict_in_generate:
+            dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)

-        if return_dict_in_generate and generation_config.return_dict_in_generate:
-            logger.warning_once(
-                "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
-            )
-            return_segments = True
-        elif not return_segments and not return_token_timestamps:
-            return padded_outputs
-
-        if return_token_timestamps:
-            sequences, token_timestamps = padded_outputs
-            outputs = {
-                "sequences": sequences,
-                "token_timestamps": token_timestamps,
-            }
-        else:
-            sequences = padded_outputs
-            outputs = {
-                "sequences": sequences,
-            }
+            if num_return_sequences > 1:
+                if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
+                    dict_outputs.encoder_attentions = tuple(
+                        dict_outputs.encoder_attentions[i][::num_return_sequences]
+                        for i in range(len(dict_outputs.encoder_attentions))
+                    )
+                if (
+                    hasattr(dict_outputs, "encoder_hidden_states")
+                    and dict_outputs.encoder_hidden_states is not None
+                ):
+                    dict_outputs.encoder_hidden_states = tuple(
+                        dict_outputs.encoder_hidden_states[i][::num_return_sequences]
+                        for i in range(len(dict_outputs.encoder_hidden_states))
+                    )
+            if return_token_timestamps:
+                dict_outputs["token_timestamps"] = outputs["token_timestamps"]
+            return dict_outputs

-        if return_segments:
-            outputs["segments"] = final_segments
+        return outputs

-        return outputs
+        return sequences

 @add_start_docstrings(
     "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
@@ -2270,132 +2174,6 @@ class CustomWhisperForConditionalGeneration(CustomWhisperGenerationMixin, Whispe

         # Initialize weights and apply final processing
         self.post_init()
-
-    def _pad_to_max_length(
-        current_segments,
-        pad_token_id,
-        device,
-        padding_side="right",
-        padding="longest",
-        bos_token_tensor=None,
-        cut_off_length=None,
-        return_token_timestamps=False,
-        force_unique_generate_call=False,
-    ):
-        max_total_length = 0
-        sequences = []
-        token_timestamps_list = []
-
-        if padding_side not in ["right", "left"]:
-            raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
-
-        if padding not in ["longest", "max_length"]:
-            raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
-        elif padding == "max_length" and cut_off_length is None:
-            raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
-
-        if force_unique_generate_call:
-            sequences_list = []
-            timestamps_list = []
-            for segments in current_segments:
-                result = segments[0]["result"]
-                sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
-                if return_token_timestamps:
-                    timestamps_list.append(result["token_timestamps"])
-
-            sequences = torch.stack(sequences_list, dim=0)
-            if return_token_timestamps:
-                token_timestamps = torch.stack(timestamps_list, dim=0)
-                return sequences, token_timestamps
-            return sequences
-
-        for current_segment_list in current_segments:
-            if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-                sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
-                if return_token_timestamps:
-                    token_timestamps = torch.cat(
-                        [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
-                        dim=-1,
-                    )
-
-                if cut_off_length is not None:
-                    sequence = sequence[-cut_off_length:]
-                    if return_token_timestamps:
-                        token_timestamps = token_timestamps[-cut_off_length:]
-
-                if bos_token_tensor is not None:
-                    sequence = torch.cat([bos_token_tensor, sequence])
-                    if return_token_timestamps:
-                        token_timestamps = torch.cat(
-                            [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
-                        )
-                sequences.append(sequence)
-                if return_token_timestamps:
-                    token_timestamps_list.append(token_timestamps)
-                max_total_length = max(max_total_length, len(sequences[-1]))
-            elif bos_token_tensor is not None:
-                sequences.append(bos_token_tensor)
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
-            else:
-                sequences.append(torch.tensor([], device=device))
-                if return_token_timestamps:
-                    token_timestamps_list.append(torch.tensor([], device=device))
-
-        max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
-        for i in range(len(current_segments)):
-            pad_length = max_total_length - len(sequences[i])
-            pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
-
-            sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
-            if return_token_timestamps:
-                token_timestamps_list[i] = F.pad(
-                    token_timestamps_list[i],
-                    pad=pad,
-                    value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
-                )
-
-        sequences = torch.stack(sequences, dim=0)
-
-        if return_token_timestamps:
-            token_timestamps = torch.stack(token_timestamps_list, dim=0)
-            return sequences, token_timestamps
-        else:
-            return sequences
-
-        padded_outputs = _pad_to_max_length(
-            current_segments=final_segments,
-            pad_token_id=generation_config.pad_token_id,
-            device=self.device,
-            padding_side="right",
-            return_token_timestamps=return_token_timestamps,
-            force_unique_generate_call=force_unique_generate_call,
-        )
-
-        if return_dict_in_generate and generation_config.return_dict_in_generate:
-            logger.warning_once(
-                "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
-            )
-            return_segments = True
-        elif not return_segments and not return_token_timestamps:
-            return padded_outputs
-
-        if return_token_timestamps:
-            sequences, token_timestamps = padded_outputs
-            outputs = {
-                "sequences": sequences,
-                "token_timestamps": token_timestamps,
-            }
-        else:
-            sequences = padded_outputs
-            outputs = {
-                "sequences": sequences,
-            }
-
-        if return_segments:
-            outputs["segments"] = final_segments
-
-        return outputs

     def get_encoder(self):
         return self.model.get_encoder()