asr / main.py
HoneyTian's picture
update
e67d20d
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
from datetime import datetime
import functools
import logging
import os
from pathlib import Path
import platform
import time
import tempfile
import hashlib
from project_settings import project_path, log_directory
import log
log.setup(log_directory=log_directory)
import gradio as gr
import torch
import torchaudio
from toolbox.k2_sherpa.examples import examples
from toolbox.k2_sherpa import decode, nn_models
from toolbox.k2_sherpa.utils import audio_convert
# Named logger used for this script's processing/request logs. Its handlers
# are presumably configured by the project's `log.setup(...)` call made
# during the imports — confirm against the `log` module.
main_logger = logging.getLogger("main")
def get_args():
    """Parse the command-line options of the ASR web demo.

    Returns:
        argparse.Namespace with a single ``pretrained_model_dir`` attribute:
        a POSIX-style path string that defaults to
        ``<project_path>/pretrained_models``.
    """
    default_model_dir = (project_path / "pretrained_models").as_posix()
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--pretrained_model_dir",
        type=str,
        default=default_model_dir,
    )
    return arg_parser.parse_args()
def update_model_dropdown(language: str):
    """Rebuild the model dropdown after the language radio changes.

    Args:
        language: a key of ``nn_models.model_map``.

    Returns:
        An interactive ``gr.Dropdown`` listing the ``repo_id`` of every model
        registered for ``language``, with the first one preselected.

    Raises:
        ValueError: if ``language`` is not a key of ``nn_models.model_map``.
    """
    model_map = nn_models.model_map
    if language not in model_map:
        raise ValueError(f"Unsupported language: {language}")
    repo_ids = [entry["repo_id"] for entry in model_map[language]]
    return gr.Dropdown(choices=repo_ids, value=repo_ids[0], interactive=True)
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap ``s`` in the two nested result ``<div>`` elements.

    ``s`` is inserted verbatim (it is not HTML-escaped), so callers may pass
    markup such as ``<br/>``. ``style`` is the extra CSS class applied to the
    inner div, e.g. ``"result_item_success"`` or ``"result_item_error"``.
    """
    pieces = [
        "",
        "<div class='result'>",
        f"<div class='result_item {style}'>",
        f"{s}",
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(pieces)
def md5_encrypt(text: str) -> str:
    """Return the MD5 hex digest of ``text`` encoded as UTF-8.

    Despite the name this is a one-way hash, not encryption. The result is
    always 32 lowercase hexadecimal characters; this file uses it to turn
    long repo names into short directory names.
    """
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
def _find_model_config(language: str, repo_id: str) -> dict:
    """Return the ``nn_models.model_map`` entry for ``repo_id`` under ``language``.

    If the same ``repo_id`` is listed more than once, the last entry wins.

    Raises:
        AssertionError: if ``language`` is unknown or has no model ``repo_id``.
    """
    m_list = nn_models.model_map.get(language)
    if m_list is None:
        raise AssertionError("language invalid: {}".format(language))
    matches = [m for m in m_list if m["repo_id"] == repo_id]
    if not matches:
        raise AssertionError("repo_id invalid: {}".format(repo_id))
    return matches[-1]


def _repo_id_to_folder(repo_id: str) -> str:
    """Map ``"name"`` or ``"supplier/name"`` to a folder under ``huggingface/``.

    A repo name longer than 40 characters is replaced by its 32-character MD5
    hex digest (presumably to keep paths short). The supplier part is kept
    verbatim.

    Raises:
        AssertionError: if ``repo_id`` does not have exactly one or two parts.
    """
    parts = Path(repo_id).parts
    if len(parts) not in (1, 2):
        raise AssertionError("repo_id parts count invalid: {}".format(len(parts)))
    repo_name = parts[-1]
    if len(repo_name) > 40:
        repo_name = md5_encrypt(repo_name)
    if len(parts) == 1:
        return repo_name
    return "{}/{}".format(parts[0], repo_name)


@torch.no_grad()
def process(
    language: str,
    repo_id: str,
    decoding_method: str,
    num_active_paths: int,
    add_punctuation: str,
    in_filename: str,
    pretrained_model_dir: Path,
):
    """Transcribe one audio file and report timing statistics.

    Steps:
        1. Convert ``in_filename`` with ``audio_convert`` into
           ``<tempdir>/asr/<same file name>``.
        2. Load the recognizer for (``language``, ``repo_id``) from
           ``pretrained_model_dir/huggingface/...``.
        3. Decode the audio.
        4. Optionally restore punctuation with a fixed zh/en punctuation model.

    Args:
        language: a key of ``nn_models.model_map``.
        repo_id: ``repo_id`` of one of that language's models.
        decoding_method: forwarded to ``nn_models.load_recognizer``.
        num_active_paths: forwarded to ``nn_models.load_recognizer``.
        add_punctuation: ``"Yes"`` runs the punctuation model; any other
            value skips it.
        in_filename: path of the input audio file.
        pretrained_model_dir: root directory of the local model cache.

    Returns:
        ``(text, html_info)``: the transcript, and an HTML snippet with the
        audio duration, processing time and real-time factor (RTF).

    Raises:
        AssertionError: on an unknown language/repo_id or a malformed repo_id.
    """
    time_fmt = "%Y-%m-%d %H:%M:%S.%f"

    main_logger.info("language: {}".format(language))
    main_logger.info("repo_id: {}".format(repo_id))
    main_logger.info("decoding_method: {}".format(decoding_method))
    main_logger.info("num_active_paths: {}".format(num_active_paths))
    main_logger.info("add_punctuation: {}".format(add_punctuation))
    main_logger.info("in_filename: {}".format(in_filename))

    # audio convert
    in_path = Path(in_filename)
    out_filename = Path(tempfile.gettempdir()) / "asr" / in_path.name
    out_filename.parent.mkdir(parents=True, exist_ok=True)
    audio_convert(in_filename=in_path.as_posix(),
                  out_filename=out_filename.as_posix(),
                  )

    # model settings and local cache location
    m_dict = _find_model_config(language, repo_id)
    local_model_dir = pretrained_model_dir / "huggingface" / _repo_id_to_folder(repo_id)

    # load recognizer
    recognizer = nn_models.load_recognizer(
        local_model_dir=local_model_dir,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
        **m_dict
    )

    # transcribe; the measured time covers decoding plus optional punctuation
    main_logger.info("Started at {}".format(datetime.now().strftime(time_fmt)))
    start = time.time()
    text = decode.decode_by_recognizer(recognizer=recognizer,
                                       filename=out_filename.as_posix(),
                                       )

    if add_punctuation == "Yes":
        punctuation_repo_id = "csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12"
        # NOTE: unlike recognizer models, this cache folder is the MD5 of the
        # whole repo id (kept as-is so existing downloads are reused).
        punctuation_model = nn_models.load_punctuation_model(
            local_model_dir=pretrained_model_dir / "huggingface" / md5_encrypt(punctuation_repo_id),
            repo_id=punctuation_repo_id,
            nn_model_file="model.onnx",
            nn_model_file_sub_folder=".",
        )
        text = punctuation_model.add_punctuation(text)

    end = time.time()
    elapsed = end - start
    # Take a fresh timestamp; previously the start time was reported here.
    finished_at = datetime.now().strftime(time_fmt)

    # statistics: use the converted file's actual sample rate, not an assumed 16 kHz
    metadata = torchaudio.info(out_filename.as_posix())
    duration = metadata.num_frames / metadata.sample_rate
    # Zero-length audio would otherwise raise ZeroDivisionError.
    rtf = elapsed / duration if duration > 0 else float("inf")
    main_logger.info(f"Finished at {finished_at}. Elapsed: {elapsed: .3f} s")
    info = (
        "\n"
        f"Wave duration : {duration: .3f} s <br/>\n"
        f"Processing time: {elapsed: .3f} s <br/>\n"
        f"RTF: {elapsed: .3f}/{duration: .3f} = {rtf:.3f} <br/>\n"
    )
    main_logger.info(info)
    main_logger.info(f"\nrepo_id: {repo_id}\nhyp: {text}")
    return text, build_html_output(info)
def process_uploaded_file(language: str,
                          repo_id: str,
                          decoding_method: str,
                          num_active_paths: int,
                          add_punctuation: str,
                          in_filename: str,
                          pretrained_model_dir: Path,
                          ):
    """Gradio callback: transcribe an uploaded or example file without raising.

    Args:
        language, repo_id, decoding_method, num_active_paths, add_punctuation,
        pretrained_model_dir: forwarded unchanged to ``process``.
        in_filename: path of the uploaded audio; ``None`` or ``""`` when the
            user has not uploaded anything.

    Returns:
        ``(text, html_info)``. On success this is whatever ``process``
        returns. If no file was given or any exception occurs, it is an empty
        transcript plus an error box.
    """
    import html

    if in_filename is None or in_filename == "":
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )

    main_logger.info(f"Processing uploaded file: {in_filename}")
    try:
        return process(
            in_filename=in_filename,
            language=language,
            repo_id=repo_id,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
            add_punctuation=add_punctuation,
            pretrained_model_dir=pretrained_model_dir,
        )
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the
        # request, but keep the full traceback in the log.
        msg = "transcribe error: {}".format(str(e))
        main_logger.exception(msg)
        # The message can echo the user-supplied file name, so escape it
        # before injecting it into HTML.
        return "", build_html_output(html.escape(msg), "result_item_error")
# Stylesheet passed to gr.Blocks for the markup built by build_html_output():
# `.result` is the outer flex column, `.result_item` is the rounded padded
# box, and `.result_item_success` / `.result_item_error` choose the
# mediumaquamarine (success) or #ff7070 (error) background.
# The css style is copied from
# https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""
def main():
    """Entry point: parse CLI args, build the Gradio demo and serve it.

    The UI lets the user choose:
      * a language and a model (``repo_id``);
      * a decoding method and the number of active paths for
        ``modified_beam_search``;
      * whether to add punctuation.

    It then transcribes an uploaded file, or a bundled example, through
    ``process_uploaded_file``. The server listens on port 7860, bound to
    127.0.0.1 on Windows and 0.0.0.0 elsewhere.
    """
    args = get_args()
    pretrained_model_dir = Path(args.pretrained_model_dir)
    # parents=True: a custom --pretrained_model_dir may point into a
    # directory tree that does not exist yet.
    pretrained_model_dir.mkdir(parents=True, exist_ok=True)

    # Gradio supplies only the UI component values; bind the model root here.
    process_uploaded_file_ = functools.partial(
        process_uploaded_file,
        pretrained_model_dir=pretrained_model_dir,
    )

    title = "# Automatic Speech Recognition with Next-gen Kaldi"

    language_choices = list(nn_models.model_map.keys())
    # language -> repo_ids of its models; used to fill the initial dropdown.
    language_to_models = {
        language: [m["repo_id"] for m in models]
        for language, models in nn_models.model_map.items()
    }
    default_language = language_choices[0]

    # blocks
    with gr.Blocks(css=css) as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                language_radio = gr.Radio(
                    label="Language",
                    choices=language_choices,
                    value=default_language,
                )
                model_dropdown = gr.Dropdown(
                    choices=language_to_models[default_language],
                    label="Select a model",
                    value=language_to_models[default_language][0],
                    allow_custom_value=True,
                )
                decoding_method_radio = gr.Radio(
                    label="Decoding method",
                    choices=["greedy_search", "modified_beam_search"],
                    value="greedy_search",
                )
                num_active_paths_slider = gr.Slider(
                    minimum=1,
                    value=4,
                    step=1,
                    label="Number of active paths for modified_beam_search",
                )
                punct_radio = gr.Radio(
                    label="Whether to add punctuation (Only for Chinese and English)",
                    choices=["Yes", "No"],
                    value="Yes",
                )
                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                upload_button = gr.Button("Submit for recognition")
                uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
                uploaded_html_info = gr.HTML(label="Info")

                # Order must match process_uploaded_file's positional parameters.
                input_components = (
                    language_radio,
                    model_dropdown,
                    decoding_method_radio,
                    num_active_paths_slider,
                    punct_radio,
                    uploaded_file,
                )
                output_components = (uploaded_output, uploaded_html_info)

                gr.Examples(
                    examples=examples,
                    inputs=list(input_components),
                    outputs=list(output_components),
                    fn=process_uploaded_file_,
                )
                upload_button.click(
                    process_uploaded_file_,
                    inputs=list(input_components),
                    outputs=list(output_components),
                )
                language_radio.change(
                    update_model_dropdown,
                    inputs=language_radio,
                    outputs=model_dropdown,
                )

    on_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if on_windows else "0.0.0.0",
        server_port=7860,
    )
# Script entry point: launch the web demo only when run directly, not on import.
if __name__ == "__main__":
    main()