Spaces:
Sleeping
Sleeping
update
Browse files
examples/wenet/toolbox_download.py
ADDED
|
File without changes
|
toolbox/k2_sherpa/nn_models.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
from enum import Enum
|
| 4 |
from functools import lru_cache
|
|
|
|
| 5 |
import os
|
| 6 |
import platform
|
| 7 |
from pathlib import Path
|
|
@@ -10,6 +11,8 @@ import huggingface_hub
|
|
| 10 |
import sherpa
|
| 11 |
import sherpa_onnx
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class EnumDecodingMethod(Enum):
|
| 15 |
greedy_search = "greedy_search"
|
|
@@ -104,6 +107,7 @@ def download_model(local_model_dir: str,
|
|
| 104 |
repo_id = kwargs["repo_id"]
|
| 105 |
|
| 106 |
if "nn_model_file" in kwargs.keys():
|
|
|
|
| 107 |
_ = huggingface_hub.hf_hub_download(
|
| 108 |
repo_id=repo_id,
|
| 109 |
filename=kwargs["nn_model_file"],
|
|
@@ -112,6 +116,7 @@ def download_model(local_model_dir: str,
|
|
| 112 |
)
|
| 113 |
|
| 114 |
if "encoder_model_file" in kwargs.keys():
|
|
|
|
| 115 |
_ = huggingface_hub.hf_hub_download(
|
| 116 |
repo_id=repo_id,
|
| 117 |
filename=kwargs["encoder_model_file"],
|
|
@@ -120,6 +125,7 @@ def download_model(local_model_dir: str,
|
|
| 120 |
)
|
| 121 |
|
| 122 |
if "decoder_model_file" in kwargs.keys():
|
|
|
|
| 123 |
_ = huggingface_hub.hf_hub_download(
|
| 124 |
repo_id=repo_id,
|
| 125 |
filename=kwargs["decoder_model_file"],
|
|
@@ -128,6 +134,7 @@ def download_model(local_model_dir: str,
|
|
| 128 |
)
|
| 129 |
|
| 130 |
if "joiner_model_file" in kwargs.keys():
|
|
|
|
| 131 |
_ = huggingface_hub.hf_hub_download(
|
| 132 |
repo_id=repo_id,
|
| 133 |
filename=kwargs["joiner_model_file"],
|
|
@@ -136,6 +143,7 @@ def download_model(local_model_dir: str,
|
|
| 136 |
)
|
| 137 |
|
| 138 |
if "tokens_file" in kwargs.keys():
|
|
|
|
| 139 |
_ = huggingface_hub.hf_hub_download(
|
| 140 |
repo_id=repo_id,
|
| 141 |
filename=kwargs["tokens_file"],
|
|
@@ -158,6 +166,9 @@ def load_sherpa_offline_recognizer(nn_model_file: str,
|
|
| 158 |
feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
|
| 159 |
feat_config.fbank_opts.frame_opts.dither = frame_dither
|
| 160 |
|
|
|
|
|
|
|
|
|
|
| 161 |
config = sherpa.OfflineRecognizerConfig(
|
| 162 |
nn_model=nn_model_file,
|
| 163 |
tokens=tokens_file,
|
|
@@ -220,7 +231,7 @@ def load_recognizer(local_model_dir: Path,
|
|
| 220 |
num_active_paths: int = 4,
|
| 221 |
**kwargs
|
| 222 |
):
|
| 223 |
-
if not
|
| 224 |
download_model(
|
| 225 |
local_model_dir=local_model_dir.as_posix(),
|
| 226 |
**kwargs,
|
|
|
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
from enum import Enum
|
| 4 |
from functools import lru_cache
|
| 5 |
+
import logging
|
| 6 |
import os
|
| 7 |
import platform
|
| 8 |
from pathlib import Path
|
|
|
|
| 11 |
import sherpa
|
| 12 |
import sherpa_onnx
|
| 13 |
|
| 14 |
+
main_logger = logging.getLogger("main")
|
| 15 |
+
|
| 16 |
|
| 17 |
class EnumDecodingMethod(Enum):
|
| 18 |
greedy_search = "greedy_search"
|
|
|
|
| 107 |
repo_id = kwargs["repo_id"]
|
| 108 |
|
| 109 |
if "nn_model_file" in kwargs.keys():
|
| 110 |
+
main_logger.info("download nn_model_file. filename: {}, subfolder: {}".format(kwargs["nn_model_file"], kwargs["nn_model_file_sub_folder"]))
|
| 111 |
_ = huggingface_hub.hf_hub_download(
|
| 112 |
repo_id=repo_id,
|
| 113 |
filename=kwargs["nn_model_file"],
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
if "encoder_model_file" in kwargs.keys():
|
| 119 |
+
main_logger.info("download encoder_model_file. filename: {}, subfolder: {}".format(kwargs["encoder_model_file"], kwargs["encoder_model_file_sub_folder"]))
|
| 120 |
_ = huggingface_hub.hf_hub_download(
|
| 121 |
repo_id=repo_id,
|
| 122 |
filename=kwargs["encoder_model_file"],
|
|
|
|
| 125 |
)
|
| 126 |
|
| 127 |
if "decoder_model_file" in kwargs.keys():
|
| 128 |
+
main_logger.info("download decoder_model_file. filename: {}, subfolder: {}".format(kwargs["decoder_model_file"], kwargs["decoder_model_file_sub_folder"]))
|
| 129 |
_ = huggingface_hub.hf_hub_download(
|
| 130 |
repo_id=repo_id,
|
| 131 |
filename=kwargs["decoder_model_file"],
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
if "joiner_model_file" in kwargs.keys():
|
| 137 |
+
main_logger.info("download joiner_model_file. filename: {}, subfolder: {}".format(kwargs["joiner_model_file"], kwargs["joiner_model_file_sub_folder"]))
|
| 138 |
_ = huggingface_hub.hf_hub_download(
|
| 139 |
repo_id=repo_id,
|
| 140 |
filename=kwargs["joiner_model_file"],
|
|
|
|
| 143 |
)
|
| 144 |
|
| 145 |
if "tokens_file" in kwargs.keys():
|
| 146 |
+
main_logger.info("download tokens_file. filename: {}, subfolder: {}".format(kwargs["tokens_file"], kwargs["tokens_file_sub_folder"]))
|
| 147 |
_ = huggingface_hub.hf_hub_download(
|
| 148 |
repo_id=repo_id,
|
| 149 |
filename=kwargs["tokens_file"],
|
|
|
|
| 166 |
feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
|
| 167 |
feat_config.fbank_opts.frame_opts.dither = frame_dither
|
| 168 |
|
| 169 |
+
if not os.path.exists(nn_model_file):
|
| 170 |
+
raise AssertionError("nn_model_file not found. ")
|
| 171 |
+
|
| 172 |
config = sherpa.OfflineRecognizerConfig(
|
| 173 |
nn_model=nn_model_file,
|
| 174 |
tokens=tokens_file,
|
|
|
|
| 231 |
num_active_paths: int = 4,
|
| 232 |
**kwargs
|
| 233 |
):
|
| 234 |
+
if not local_model_dir.exists():
|
| 235 |
download_model(
|
| 236 |
local_model_dir=local_model_dir.as_posix(),
|
| 237 |
**kwargs,
|