import argparse
import os
import sys
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
import torch
import tqdm

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    optional_int,
    optional_float,
    str2bool,
    write_txt,
    write_vtt,
    write_srt,
)

if TYPE_CHECKING:
    from .model import Whisper


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
""" |
|
Transcribe an audio file using Whisper |
|
|
|
Parameters |
|
---------- |
|
model: Whisper |
|
The Whisper model instance |
|
|
|
audio: Union[str, np.ndarray, torch.Tensor] |
|
The path to the audio file to open, or the audio waveform |
|
|
|
verbose: bool |
|
Whether to display the text being decoded to the console. If True, displays all the details, |
|
If False, displays minimal details. If None, does not display anything |
|
|
|
temperature: Union[float, Tuple[float, ...]] |
|
Temperature for sampling. It can be a tuple of temperatures, which will be successively used |
|
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. |
|
|
|
compression_ratio_threshold: float |
|
If the gzip compression ratio is above this value, treat as failed |
|
|
|
logprob_threshold: float |
|
If the average log probability over sampled tokens is below this value, treat as failed |
|
|
|
no_speech_threshold: float |
|
If the no_speech probability is higher than this value AND the average log probability |
|
over sampled tokens is below `logprob_threshold`, consider the segment as silent |
|
|
|
condition_on_previous_text: bool |
|
if True, the previous output of the model is provided as a prompt for the next window; |
|
disabling may make the text inconsistent across windows, but the model becomes less prone to |
|
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. |
|
|
|
decode_options: dict |
|
Keyword arguments to construct `DecodingOptions` instances |
|
|
|
Returns |
|
------- |
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and |
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None. |
|
""" |
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    mel = log_mel_spectrogram(audio)
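
    # detect the spoken language from the first 30 seconds when none was specified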
    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
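
    # decode one 30-second window, retrying at successively higher temperatures when
    # the result fails the compression-ratio or average log-probability checks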
    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # the text is too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # the average log probability is too low

            if not needs_fallback:
                break

        return decode_result
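
    # N_FRAMES mel frames (30 seconds) map onto model.dims.n_audio_ctx token positions,
    # so each output token position spans input_stride mel frames, i.e.
    # input_stride * HOP_LENGTH / SAMPLE_RATE seconds of audio (0.02 s for the released models)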
    seek = 0
    input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)
    time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None) or []
    if initial_prompt:
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt)
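
    # record a segment's text, tokens, and quality metrics; echo it to stdout when verbose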
    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:
            return  # skip empty text output

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": text_tokens.tolist(),
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
            # compared to just `print(line)`, this replaces any character not representable
            # in the system default encoding with '?', avoiding UnicodeEncodeError
            sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
            sys.stdout.flush()

    num_frames = mel.shape[-1]
    previous_seek_value = seek
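
    # slide over the mel spectrogram in windows of up to 30 seconds, decoding each
    # window and advancing `seek` past the frames that were transcribed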
    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        # show a progress bar when verbose is False (if True, the decoded text is printed instead)
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
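            # indices of the second token in each adjacent pair of timestamp tokens;
            # such pairs mark boundaries between complete segments in the window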
            if len(consecutive) > 0:
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but the output has a timestamp; use the last one.
                    # a single timestamp at the end means there is no speech after it.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update the progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt):]),
        segments=all_segments,
        language=language,
    )


def cli():
    from . import available_models

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument(
        "--model",
        default="small",
        choices=available_models(),
        help="name of the Whisper model to use",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="the path to save model files; uses ~/.cache/whisper by default",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to use for PyTorch inference",
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        type=str,
        default=".",
        help="directory to save the outputs",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="whether to print out the progress and debug messages",
    )

    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        choices=sorted(LANGUAGES.keys())
        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
        help="language spoken in the audio, specify None to perform language detection",
    )

    parser.add_argument(
        "--temperature", type=float, default=0, help="temperature to use for sampling"
    )
    parser.add_argument(
        "--best_of",
        type=optional_int,
        default=5,
        help="number of candidates when sampling with non-zero temperature",
    )
    parser.add_argument(
        "--beam_size",
        type=optional_int,
        default=5,
        help="number of beams in beam search, only applicable when temperature is zero",
    )
    parser.add_argument(
        "--patience",
        type=float,
        default=None,
        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
    )
    parser.add_argument(
        "--length_penalty",
        type=float,
        default=None,
        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default",
    )

    parser.add_argument(
        "--suppress_tokens",
        type=str,
        default="-1",
        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
    )
    parser.add_argument(
        "--initial_prompt",
        type=str,
        default=None,
        help="optional text to provide as a prompt for the first window.",
    )
    parser.add_argument(
        "--condition_on_previous_text",
        type=str2bool,
        default=True,
        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
    )
    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=True,
        help="whether to perform inference in fp16; True by default",
    )

    parser.add_argument(
        "--temperature_increment_on_fallback",
        type=optional_float,
        default=0.2,
        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below",
    )
    parser.add_argument(
        "--compression_ratio_threshold",
        type=optional_float,
        default=2.4,
        help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--logprob_threshold",
        type=optional_float,
        default=-1.0,
        help="if the average log probability is lower than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--no_speech_threshold",
        type=optional_float,
        default=0.6,
        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence",
    )
    parser.add_argument(
        "--threads",
        type=optional_int,
        default=0,
        help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS",
    )

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"
    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(
            np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)
        )
    else:
        temperature = [temperature]

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    from . import load_model

    model = load_model(model_name, device=device, download_root=model_dir)

    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)

        audio_basename = os.path.basename(audio_path)

        # save TXT
        with open(
            os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8"
        ) as txt:
            write_txt(result["segments"], file=txt)

        # save VTT
        with open(
            os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8"
        ) as vtt:
            write_vtt(result["segments"], file=vtt)

        # save SRT
        with open(
            os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8"
        ) as srt:
            write_srt(result["segments"], file=srt)


if __name__ == "__main__":
    cli()