Spaces:
Runtime error
Runtime error
Adding support for word timestamps
Browse files- app.py +28 -12
- cli.py +14 -2
- config.json5 +10 -1
- src/config.py +11 -1
- src/utils.py +117 -8
- src/vad.py +8 -0
- src/whisper/whisperContainer.py +3 -2
app.py
CHANGED
|
@@ -100,13 +100,17 @@ class WhisperTranscriber:
|
|
| 100 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
| 101 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
| 102 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
| 103 |
-
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 106 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
| 107 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
| 108 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
| 109 |
-
compression_ratio_threshold, logprob_threshold, no_speech_threshold
|
|
|
|
| 110 |
|
| 111 |
# Entry function for the full tab with progress
|
| 112 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
|
@@ -114,6 +118,9 @@ class WhisperTranscriber:
|
|
| 114 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
| 115 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
| 116 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
|
|
|
|
|
|
|
|
|
| 117 |
progress=gr.Progress()):
|
| 118 |
|
| 119 |
# Handle temperature_increment_on_fallback
|
|
@@ -128,13 +135,15 @@ class WhisperTranscriber:
|
|
| 128 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
| 129 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
| 130 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
|
|
|
| 131 |
progress=progress)
|
| 132 |
|
| 133 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 134 |
-
vadOptions: VadOptions, progress: gr.Progress = None,
|
|
|
|
| 135 |
try:
|
| 136 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
| 137 |
-
|
| 138 |
try:
|
| 139 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
| 140 |
selectedModel = modelName if modelName is not None else "base"
|
|
@@ -185,7 +194,7 @@ class WhisperTranscriber:
|
|
| 185 |
# Update progress
|
| 186 |
current_progress += source_audio_duration
|
| 187 |
|
| 188 |
-
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
|
| 189 |
|
| 190 |
if len(sources) > 1:
|
| 191 |
# Add new line separators
|
|
@@ -359,7 +368,7 @@ class WhisperTranscriber:
|
|
| 359 |
|
| 360 |
return config
|
| 361 |
|
| 362 |
-
def write_result(self, result: dict, source_name: str, output_dir: str):
|
| 363 |
if not os.path.exists(output_dir):
|
| 364 |
os.makedirs(output_dir)
|
| 365 |
|
|
@@ -368,8 +377,8 @@ class WhisperTranscriber:
|
|
| 368 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
| 369 |
|
| 370 |
print("Max line width " + str(languageMaxLineWidth))
|
| 371 |
-
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
|
| 372 |
-
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
|
| 373 |
|
| 374 |
output_files = []
|
| 375 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
|
@@ -394,13 +403,13 @@ class WhisperTranscriber:
|
|
| 394 |
# 80 latin characters should fit on a 1080p/720p screen
|
| 395 |
return 80
|
| 396 |
|
| 397 |
-
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
|
| 398 |
segmentStream = StringIO()
|
| 399 |
|
| 400 |
if format == 'vtt':
|
| 401 |
-
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
| 402 |
elif format == 'srt':
|
| 403 |
-
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
| 404 |
else:
|
| 405 |
raise Exception("Unknown format " + format)
|
| 406 |
|
|
@@ -501,7 +510,14 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 501 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
| 502 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
| 503 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
| 504 |
-
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
], outputs=[
|
| 506 |
gr.File(label="Download"),
|
| 507 |
gr.Text(label="Transcription"),
|
|
|
|
| 100 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
| 101 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
| 102 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
| 103 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
| 104 |
+
# Word timestamps
|
| 105 |
+
word_timestamps: bool, prepend_punctuations: str,
|
| 106 |
+
append_punctuations: str, highlight_words: bool = False):
|
| 107 |
|
| 108 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 109 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
| 110 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
| 111 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
| 112 |
+
compression_ratio_threshold, logprob_threshold, no_speech_threshold,
|
| 113 |
+
word_timestamps, prepend_punctuations, append_punctuations, highlight_words)
|
| 114 |
|
| 115 |
# Entry function for the full tab with progress
|
| 116 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
|
|
|
| 118 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
| 119 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
| 120 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
| 121 |
+
# Word timestamps
|
| 122 |
+
word_timestamps: bool, prepend_punctuations: str,
|
| 123 |
+
append_punctuations: str, highlight_words: bool = False,
|
| 124 |
progress=gr.Progress()):
|
| 125 |
|
| 126 |
# Handle temperature_increment_on_fallback
|
|
|
|
| 135 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
| 136 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
| 137 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
| 138 |
+
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
| 139 |
progress=progress)
|
| 140 |
|
| 141 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 142 |
+
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
| 143 |
+
**decodeOptions: dict):
|
| 144 |
try:
|
| 145 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
| 146 |
+
|
| 147 |
try:
|
| 148 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
| 149 |
selectedModel = modelName if modelName is not None else "base"
|
|
|
|
| 194 |
# Update progress
|
| 195 |
current_progress += source_audio_duration
|
| 196 |
|
| 197 |
+
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
| 198 |
|
| 199 |
if len(sources) > 1:
|
| 200 |
# Add new line separators
|
|
|
|
| 368 |
|
| 369 |
return config
|
| 370 |
|
| 371 |
+
def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
|
| 372 |
if not os.path.exists(output_dir):
|
| 373 |
os.makedirs(output_dir)
|
| 374 |
|
|
|
|
| 377 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
| 378 |
|
| 379 |
print("Max line width " + str(languageMaxLineWidth))
|
| 380 |
+
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
| 381 |
+
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
| 382 |
|
| 383 |
output_files = []
|
| 384 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
|
|
|
| 403 |
# 80 latin characters should fit on a 1080p/720p screen
|
| 404 |
return 80
|
| 405 |
|
| 406 |
+
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
|
| 407 |
segmentStream = StringIO()
|
| 408 |
|
| 409 |
if format == 'vtt':
|
| 410 |
+
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
| 411 |
elif format == 'srt':
|
| 412 |
+
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
| 413 |
else:
|
| 414 |
raise Exception("Unknown format " + format)
|
| 415 |
|
|
|
|
| 510 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
| 511 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
| 512 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
| 513 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
|
| 514 |
+
|
| 515 |
+
# Word timestamps
|
| 516 |
+
gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
|
| 517 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
| 518 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
|
| 519 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
| 520 |
+
|
| 521 |
], outputs=[
|
| 522 |
gr.File(label="Download"),
|
| 523 |
gr.Text(label="Transcription"),
|
cli.py
CHANGED
|
@@ -95,6 +95,17 @@ def cli():
|
|
| 95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
| 96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
args = parser.parse_args().__dict__
|
| 99 |
model_name: str = args.pop("model")
|
| 100 |
model_dir: str = args.pop("model_dir")
|
|
@@ -126,6 +137,7 @@ def cli():
|
|
| 126 |
auto_parallel = args.pop("auto_parallel")
|
| 127 |
|
| 128 |
compute_type = args.pop("compute_type")
|
|
|
|
| 129 |
|
| 130 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
| 131 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
@@ -133,7 +145,7 @@ def cli():
|
|
| 133 |
|
| 134 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
| 135 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
| 136 |
-
|
| 137 |
if (transcriber._has_parallel_devices()):
|
| 138 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
| 139 |
|
|
@@ -158,7 +170,7 @@ def cli():
|
|
| 158 |
|
| 159 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
| 160 |
|
| 161 |
-
transcriber.write_result(result, source_name, output_dir)
|
| 162 |
|
| 163 |
transcriber.close()
|
| 164 |
|
|
|
|
| 95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
| 96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
| 97 |
|
| 98 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
|
| 99 |
+
help="(experimental) extract word-level timestamps and refine the results based on them")
|
| 100 |
+
parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
|
| 101 |
+
help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
| 102 |
+
parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
|
| 103 |
+
help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
| 104 |
+
parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
|
| 105 |
+
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
| 106 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
| 107 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 108 |
+
|
| 109 |
args = parser.parse_args().__dict__
|
| 110 |
model_name: str = args.pop("model")
|
| 111 |
model_dir: str = args.pop("model_dir")
|
|
|
|
| 137 |
auto_parallel = args.pop("auto_parallel")
|
| 138 |
|
| 139 |
compute_type = args.pop("compute_type")
|
| 140 |
+
highlight_words = args.pop("highlight_words")
|
| 141 |
|
| 142 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
| 143 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
|
|
| 145 |
|
| 146 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
| 147 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
| 148 |
+
|
| 149 |
if (transcriber._has_parallel_devices()):
|
| 150 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
| 151 |
|
|
|
|
| 170 |
|
| 171 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
| 172 |
|
| 173 |
+
transcriber.write_result(result, source_name, output_dir, highlight_words)
|
| 174 |
|
| 175 |
transcriber.close()
|
| 176 |
|
config.json5
CHANGED
|
@@ -128,5 +128,14 @@
|
|
| 128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
| 129 |
"logprob_threshold": -1.0,
|
| 130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
| 131 |
-
"no_speech_threshold": 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
}
|
|
|
|
| 128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
| 129 |
"logprob_threshold": -1.0,
|
| 130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
| 131 |
+
"no_speech_threshold": 0.6,
|
| 132 |
+
|
| 133 |
+
// (experimental) extract word-level timestamps and refine the results based on them
|
| 134 |
+
"word_timestamps": false,
|
| 135 |
+
// if word_timestamps is True, merge these punctuation symbols with the next word
|
| 136 |
+
"prepend_punctuations": "\"\'“¿([{-",
|
| 137 |
+
// if word_timestamps is True, merge these punctuation symbols with the previous word
|
| 138 |
+
"append_punctuations": "\"\'.。,,!!??::”)]}、",
|
| 139 |
+
// (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
|
| 140 |
+
"highlight_words": false,
|
| 141 |
}
|
src/config.py
CHANGED
|
@@ -58,7 +58,11 @@ class ApplicationConfig:
|
|
| 58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
| 59 |
compute_type: str = "float16",
|
| 60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
| 61 |
-
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
self.models = models
|
| 64 |
|
|
@@ -104,6 +108,12 @@ class ApplicationConfig:
|
|
| 104 |
self.logprob_threshold = logprob_threshold
|
| 105 |
self.no_speech_threshold = no_speech_threshold
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
def get_model_names(self):
|
| 108 |
return [ x.name for x in self.models ]
|
| 109 |
|
|
|
|
| 58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
| 59 |
compute_type: str = "float16",
|
| 60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
| 61 |
+
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
| 62 |
+
# Word timestamp settings
|
| 63 |
+
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
| 64 |
+
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
| 65 |
+
highlight_words: bool = False):
|
| 66 |
|
| 67 |
self.models = models
|
| 68 |
|
|
|
|
| 108 |
self.logprob_threshold = logprob_threshold
|
| 109 |
self.no_speech_threshold = no_speech_threshold
|
| 110 |
|
| 111 |
+
# Word timestamp settings
|
| 112 |
+
self.word_timestamps = word_timestamps
|
| 113 |
+
self.prepend_punctuations = prepend_punctuations
|
| 114 |
+
self.append_punctuations = append_punctuations
|
| 115 |
+
self.highlight_words = highlight_words
|
| 116 |
+
|
| 117 |
def get_model_names(self):
|
| 118 |
return [ x.name for x in self.models ]
|
| 119 |
|
src/utils.py
CHANGED
|
@@ -3,7 +3,7 @@ import unicodedata
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
import zlib
|
| 6 |
-
from typing import Iterator, TextIO
|
| 7 |
import tqdm
|
| 8 |
|
| 9 |
import urllib3
|
|
@@ -56,10 +56,14 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
|
|
| 56 |
print(segment['text'].strip(), file=file, flush=True)
|
| 57 |
|
| 58 |
|
| 59 |
-
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
|
|
|
|
|
|
|
|
|
| 60 |
print("WEBVTT\n", file=file)
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
print(
|
| 65 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
|
@@ -68,8 +72,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
| 68 |
flush=True,
|
| 69 |
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
"""
|
| 74 |
Write a transcript to a file in SRT format.
|
| 75 |
Example usage:
|
|
@@ -81,8 +85,10 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
| 81 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
| 82 |
write_srt(result["segments"], file=srt)
|
| 83 |
"""
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# write srt lines
|
| 88 |
print(
|
|
@@ -94,6 +100,109 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
| 94 |
flush=True,
|
| 95 |
)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
def process_text(text: str, maxLineWidth=None):
|
| 98 |
if (maxLineWidth is None or maxLineWidth < 0):
|
| 99 |
return text
|
|
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
import zlib
|
| 6 |
+
from typing import Iterator, TextIO, Union
|
| 7 |
import tqdm
|
| 8 |
|
| 9 |
import urllib3
|
|
|
|
| 56 |
print(segment['text'].strip(), file=file, flush=True)
|
| 57 |
|
| 58 |
|
| 59 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
| 60 |
+
maxLineWidth=None, highlight_words: bool = False):
|
| 61 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
| 62 |
+
|
| 63 |
print("WEBVTT\n", file=file)
|
| 64 |
+
|
| 65 |
+
for segment in iterator:
|
| 66 |
+
text = segment['text'].replace('-->', '->')
|
| 67 |
|
| 68 |
print(
|
| 69 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
|
|
|
| 72 |
flush=True,
|
| 73 |
)
|
| 74 |
|
| 75 |
+
def write_srt(transcript: Iterator[dict], file: TextIO,
|
| 76 |
+
maxLineWidth=None, highlight_words: bool = False):
|
| 77 |
"""
|
| 78 |
Write a transcript to a file in SRT format.
|
| 79 |
Example usage:
|
|
|
|
| 85 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
| 86 |
write_srt(result["segments"], file=srt)
|
| 87 |
"""
|
| 88 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
| 89 |
+
|
| 90 |
+
for i, segment in enumerate(iterator, start=1):
|
| 91 |
+
text = segment['text'].replace('-->', '->')
|
| 92 |
|
| 93 |
# write srt lines
|
| 94 |
print(
|
|
|
|
| 100 |
flush=True,
|
| 101 |
)
|
| 102 |
|
| 103 |
+
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
|
| 104 |
+
for segment in transcript:
|
| 105 |
+
words = segment.get('words', [])
|
| 106 |
+
|
| 107 |
+
if len(words) == 0:
|
| 108 |
+
# Yield the segment as-is
|
| 109 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
| 110 |
+
yield segment
|
| 111 |
+
|
| 112 |
+
# Yield the segment with processed text
|
| 113 |
+
yield {
|
| 114 |
+
'start': segment['start'],
|
| 115 |
+
'end': segment['end'],
|
| 116 |
+
'text': process_text(segment['text'].strip(), maxLineWidth)
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
subtitle_start = segment['start']
|
| 120 |
+
subtitle_end = segment['end']
|
| 121 |
+
|
| 122 |
+
text_words = [ this_word["word"] for this_word in words ]
|
| 123 |
+
subtitle_text = __join_words(text_words, maxLineWidth)
|
| 124 |
+
|
| 125 |
+
# Iterate over the words in the segment
|
| 126 |
+
if highlight_words:
|
| 127 |
+
last = subtitle_start
|
| 128 |
+
|
| 129 |
+
for i, this_word in enumerate(words):
|
| 130 |
+
start = this_word['start']
|
| 131 |
+
end = this_word['end']
|
| 132 |
+
|
| 133 |
+
if last != start:
|
| 134 |
+
# Display the text up to this point
|
| 135 |
+
yield {
|
| 136 |
+
'start': last,
|
| 137 |
+
'end': start,
|
| 138 |
+
'text': subtitle_text
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
# Display the text with the current word highlighted
|
| 142 |
+
yield {
|
| 143 |
+
'start': start,
|
| 144 |
+
'end': end,
|
| 145 |
+
'text': __join_words(
|
| 146 |
+
[
|
| 147 |
+
{
|
| 148 |
+
"word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
| 149 |
+
if j == i
|
| 150 |
+
else word,
|
| 151 |
+
# The HTML tags <u> and </u> are not displayed,
|
| 152 |
+
# # so they should not be counted in the word length
|
| 153 |
+
"length": len(word)
|
| 154 |
+
} for j, word in enumerate(text_words)
|
| 155 |
+
], maxLineWidth)
|
| 156 |
+
}
|
| 157 |
+
last = end
|
| 158 |
+
|
| 159 |
+
if last != subtitle_end:
|
| 160 |
+
# Display the last part of the text
|
| 161 |
+
yield {
|
| 162 |
+
'start': last,
|
| 163 |
+
'end': subtitle_end,
|
| 164 |
+
'text': subtitle_text
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
# Just return the subtitle text
|
| 168 |
+
else:
|
| 169 |
+
yield {
|
| 170 |
+
'start': subtitle_start,
|
| 171 |
+
'end': subtitle_end,
|
| 172 |
+
'text': subtitle_text
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
|
| 176 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
| 177 |
+
return " ".join(words)
|
| 178 |
+
|
| 179 |
+
lines = []
|
| 180 |
+
current_line = ""
|
| 181 |
+
current_length = 0
|
| 182 |
+
|
| 183 |
+
for entry in words:
|
| 184 |
+
# Either accept a string or a dict with a 'word' and 'length' field
|
| 185 |
+
if isinstance(entry, dict):
|
| 186 |
+
word = entry['word']
|
| 187 |
+
word_length = entry['length']
|
| 188 |
+
else:
|
| 189 |
+
word = entry
|
| 190 |
+
word_length = len(word)
|
| 191 |
+
|
| 192 |
+
if current_length > 0 and current_length + word_length > maxLineWidth:
|
| 193 |
+
lines.append(current_line)
|
| 194 |
+
current_line = ""
|
| 195 |
+
current_length = 0
|
| 196 |
+
|
| 197 |
+
current_length += word_length
|
| 198 |
+
# The word will be prefixed with a space by Whisper, so we don't need to add one here
|
| 199 |
+
current_line += word
|
| 200 |
+
|
| 201 |
+
if len(current_line) > 0:
|
| 202 |
+
lines.append(current_line)
|
| 203 |
+
|
| 204 |
+
return "\n".join(lines)
|
| 205 |
+
|
| 206 |
def process_text(text: str, maxLineWidth=None):
|
| 207 |
if (maxLineWidth is None or maxLineWidth < 0):
|
| 208 |
return text
|
src/vad.py
CHANGED
|
@@ -404,6 +404,14 @@ class AbstractTranscription(ABC):
|
|
| 404 |
# Add to start and end
|
| 405 |
new_segment['start'] = segment_start + adjust_seconds
|
| 406 |
new_segment['end'] = segment_end + adjust_seconds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
result.append(new_segment)
|
| 408 |
return result
|
| 409 |
|
|
|
|
| 404 |
# Add to start and end
|
| 405 |
new_segment['start'] = segment_start + adjust_seconds
|
| 406 |
new_segment['end'] = segment_end + adjust_seconds
|
| 407 |
+
|
| 408 |
+
# Handle words
|
| 409 |
+
if ('words' in new_segment):
|
| 410 |
+
for word in new_segment['words']:
|
| 411 |
+
# Adjust start and end
|
| 412 |
+
word['start'] = word['start'] + adjust_seconds
|
| 413 |
+
word['end'] = word['end'] + adjust_seconds
|
| 414 |
+
|
| 415 |
result.append(new_segment)
|
| 416 |
return result
|
| 417 |
|
src/whisper/whisperContainer.py
CHANGED
|
@@ -203,8 +203,9 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
| 203 |
|
| 204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
| 205 |
|
| 206 |
-
|
| 207 |
language=self.language if self.language else detected_language, task=self.task, \
|
| 208 |
initial_prompt=initial_prompt, \
|
| 209 |
**decodeOptions
|
| 210 |
-
)
|
|
|
|
|
|
| 203 |
|
| 204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
| 205 |
|
| 206 |
+
result = model.transcribe(audio, \
|
| 207 |
language=self.language if self.language else detected_language, task=self.task, \
|
| 208 |
initial_prompt=initial_prompt, \
|
| 209 |
**decodeOptions
|
| 210 |
+
)
|
| 211 |
+
return result
|