fix generate
Browse files
model.py
CHANGED
@@ -2130,6 +2130,132 @@ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
|
|
2130 |
for i in range(len(outputs.encoder_hidden_states))
|
2131 |
)
|
2132 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2133 |
|
2134 |
@add_start_docstrings(
|
2135 |
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
|
|
|
2130 |
for i in range(len(outputs.encoder_hidden_states))
|
2131 |
)
|
2132 |
return outputs
|
2133 |
+
|
2134 |
+
def _pad_to_max_length(
|
2135 |
+
current_segments,
|
2136 |
+
pad_token_id,
|
2137 |
+
device,
|
2138 |
+
padding_side="right",
|
2139 |
+
padding="longest",
|
2140 |
+
bos_token_tensor=None,
|
2141 |
+
cut_off_length=None,
|
2142 |
+
return_token_timestamps=False,
|
2143 |
+
force_unique_generate_call=False,
|
2144 |
+
):
|
2145 |
+
max_total_length = 0
|
2146 |
+
sequences = []
|
2147 |
+
token_timestamps_list = []
|
2148 |
+
|
2149 |
+
if padding_side not in ["right", "left"]:
|
2150 |
+
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
|
2151 |
+
|
2152 |
+
if padding not in ["longest", "max_length"]:
|
2153 |
+
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
|
2154 |
+
elif padding == "max_length" and cut_off_length is None:
|
2155 |
+
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
|
2156 |
+
|
2157 |
+
if force_unique_generate_call:
|
2158 |
+
sequences_list = []
|
2159 |
+
timestamps_list = []
|
2160 |
+
for segments in current_segments:
|
2161 |
+
result = segments[0]["result"]
|
2162 |
+
sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
|
2163 |
+
if return_token_timestamps:
|
2164 |
+
timestamps_list.append(result["token_timestamps"])
|
2165 |
+
|
2166 |
+
sequences = torch.stack(sequences_list, dim=0)
|
2167 |
+
if return_token_timestamps:
|
2168 |
+
token_timestamps = torch.stack(timestamps_list, dim=0)
|
2169 |
+
return sequences, token_timestamps
|
2170 |
+
return sequences
|
2171 |
+
|
2172 |
+
for current_segment_list in current_segments:
|
2173 |
+
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
|
2174 |
+
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
|
2175 |
+
if return_token_timestamps:
|
2176 |
+
token_timestamps = torch.cat(
|
2177 |
+
[d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
|
2178 |
+
dim=-1,
|
2179 |
+
)
|
2180 |
+
|
2181 |
+
if cut_off_length is not None:
|
2182 |
+
sequence = sequence[-cut_off_length:]
|
2183 |
+
if return_token_timestamps:
|
2184 |
+
token_timestamps = token_timestamps[-cut_off_length:]
|
2185 |
+
|
2186 |
+
if bos_token_tensor is not None:
|
2187 |
+
sequence = torch.cat([bos_token_tensor, sequence])
|
2188 |
+
if return_token_timestamps:
|
2189 |
+
token_timestamps = torch.cat(
|
2190 |
+
[torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
|
2191 |
+
)
|
2192 |
+
sequences.append(sequence)
|
2193 |
+
if return_token_timestamps:
|
2194 |
+
token_timestamps_list.append(token_timestamps)
|
2195 |
+
max_total_length = max(max_total_length, len(sequences[-1]))
|
2196 |
+
elif bos_token_tensor is not None:
|
2197 |
+
sequences.append(bos_token_tensor)
|
2198 |
+
if return_token_timestamps:
|
2199 |
+
token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
|
2200 |
+
else:
|
2201 |
+
sequences.append(torch.tensor([], device=device))
|
2202 |
+
if return_token_timestamps:
|
2203 |
+
token_timestamps_list.append(torch.tensor([], device=device))
|
2204 |
+
|
2205 |
+
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
|
2206 |
+
for i in range(len(current_segments)):
|
2207 |
+
pad_length = max_total_length - len(sequences[i])
|
2208 |
+
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
|
2209 |
+
|
2210 |
+
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
|
2211 |
+
if return_token_timestamps:
|
2212 |
+
token_timestamps_list[i] = F.pad(
|
2213 |
+
token_timestamps_list[i],
|
2214 |
+
pad=pad,
|
2215 |
+
value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
|
2216 |
+
)
|
2217 |
+
|
2218 |
+
sequences = torch.stack(sequences, dim=0)
|
2219 |
+
|
2220 |
+
if return_token_timestamps:
|
2221 |
+
token_timestamps = torch.stack(token_timestamps_list, dim=0)
|
2222 |
+
return sequences, token_timestamps
|
2223 |
+
else:
|
2224 |
+
return sequences
|
2225 |
+
|
2226 |
+
padded_outputs = _pad_to_max_length(
|
2227 |
+
current_segments=final_segments,
|
2228 |
+
pad_token_id=generation_config.pad_token_id,
|
2229 |
+
device=self.device,
|
2230 |
+
padding_side="right",
|
2231 |
+
return_token_timestamps=return_token_timestamps,
|
2232 |
+
force_unique_generate_call=force_unique_generate_call,
|
2233 |
+
)
|
2234 |
+
|
2235 |
+
if return_dict_in_generate and generation_config.return_dict_in_generate:
|
2236 |
+
logger.warning_once(
|
2237 |
+
"You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
|
2238 |
+
)
|
2239 |
+
return_segments = True
|
2240 |
+
elif not return_segments and not return_token_timestamps:
|
2241 |
+
return padded_outputs
|
2242 |
+
|
2243 |
+
if return_token_timestamps:
|
2244 |
+
sequences, token_timestamps = padded_outputs
|
2245 |
+
outputs = {
|
2246 |
+
"sequences": sequences,
|
2247 |
+
"token_timestamps": token_timestamps,
|
2248 |
+
}
|
2249 |
+
else:
|
2250 |
+
sequences = padded_outputs
|
2251 |
+
outputs = {
|
2252 |
+
"sequences": sequences,
|
2253 |
+
}
|
2254 |
+
|
2255 |
+
if return_segments:
|
2256 |
+
outputs["segments"] = final_segments
|
2257 |
+
|
2258 |
+
return outputs
|
2259 |
|
2260 |
@add_start_docstrings(
|
2261 |
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
|