HoneyTian commited on
Commit
168b5c0
Β·
1 Parent(s): 7e1376c
exception.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class ExpectedError(Exception):
4
+ def __init__(self, status_code, message, traceback="", detail=""):
5
+ self.status_code = status_code
6
+ self.message = message
7
+ self.traceback = traceback
8
+ self.detail = detail
log.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from logging.handlers import TimedRotatingFileHandler
5
+ import os
6
+
7
+
8
+ def setup(log_directory: str):
9
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
10
+
11
+ stream_handler = logging.StreamHandler()
12
+ stream_handler.setLevel(logging.INFO)
13
+ stream_handler.setFormatter(logging.Formatter(fmt))
14
+
15
+ # main
16
+ main_logger = logging.getLogger("main")
17
+ main_logger.addHandler(stream_handler)
18
+ main_info_file_handler = TimedRotatingFileHandler(
19
+ filename=os.path.join(log_directory, "main.log"),
20
+ encoding="utf-8",
21
+ when="midnight",
22
+ interval=1,
23
+ backupCount=30
24
+ )
25
+ main_info_file_handler.setLevel(logging.INFO)
26
+ main_info_file_handler.setFormatter(logging.Formatter(fmt))
27
+ main_logger.addHandler(main_info_file_handler)
28
+
29
+ # http
30
+ http_logger = logging.getLogger("http")
31
+ http_file_handler = TimedRotatingFileHandler(
32
+ filename=os.path.join(log_directory, "http.log"),
33
+ encoding='utf-8',
34
+ when="midnight",
35
+ interval=1,
36
+ backupCount=30
37
+ )
38
+ http_file_handler.setLevel(logging.DEBUG)
39
+ http_file_handler.setFormatter(logging.Formatter(fmt))
40
+ http_logger.addHandler(http_file_handler)
41
+
42
+ # api
43
+ api_logger = logging.getLogger("api")
44
+ api_file_handler = TimedRotatingFileHandler(
45
+ filename=os.path.join(log_directory, "api.log"),
46
+ encoding='utf-8',
47
+ when="midnight",
48
+ interval=1,
49
+ backupCount=30
50
+ )
51
+ api_file_handler.setLevel(logging.DEBUG)
52
+ api_file_handler.setFormatter(logging.Formatter(fmt))
53
+ api_logger.addHandler(api_file_handler)
54
+
55
+ # alarm
56
+ alarm_logger = logging.getLogger("alarm")
57
+ alarm_file_handler = TimedRotatingFileHandler(
58
+ filename=os.path.join(log_directory, "alarm.log"),
59
+ encoding="utf-8",
60
+ when="midnight",
61
+ interval=1,
62
+ backupCount=30
63
+ )
64
+ alarm_file_handler.setLevel(logging.DEBUG)
65
+ alarm_file_handler.setFormatter(logging.Formatter(fmt))
66
+ alarm_logger.addHandler(alarm_file_handler)
67
+
68
+ debug_file_handler = TimedRotatingFileHandler(
69
+ filename=os.path.join(log_directory, "debug.log"),
70
+ encoding="utf-8",
71
+ when="D",
72
+ interval=1,
73
+ backupCount=7
74
+ )
75
+ debug_file_handler.setLevel(logging.DEBUG)
76
+ debug_file_handler.setFormatter(logging.Formatter(fmt))
77
+
78
+ info_file_handler = TimedRotatingFileHandler(
79
+ filename=os.path.join(log_directory, "info.log"),
80
+ encoding="utf-8",
81
+ when="D",
82
+ interval=1,
83
+ backupCount=7
84
+ )
85
+ info_file_handler.setLevel(logging.INFO)
86
+ info_file_handler.setFormatter(logging.Formatter(fmt))
87
+
88
+ error_file_handler = TimedRotatingFileHandler(
89
+ filename=os.path.join(log_directory, "error.log"),
90
+ encoding="utf-8",
91
+ when="D",
92
+ interval=1,
93
+ backupCount=7
94
+ )
95
+ error_file_handler.setLevel(logging.ERROR)
96
+ error_file_handler.setFormatter(logging.Formatter(fmt))
97
+
98
+ logging.basicConfig(
99
+ level=logging.DEBUG,
100
+ datefmt="%a, %d %b %Y %H:%M:%S",
101
+ handlers=[
102
+ debug_file_handler,
103
+ info_file_handler,
104
+ error_file_handler,
105
+ ]
106
+ )
107
+
108
+
109
+ if __name__ == "__main__":
110
+ pass
main.py CHANGED
@@ -2,25 +2,36 @@
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
  from collections import defaultdict
 
 
 
 
 
5
  import platform
 
 
 
 
 
 
6
 
7
  import gradio as gr
 
 
8
 
9
- from examples import examples
10
- from models import model_map
11
- from project_settings import project_path
 
 
12
 
13
 
14
  def get_args():
15
  parser = argparse.ArgumentParser()
 
16
  parser.add_argument(
17
- "--examples_dir",
18
- default=(project_path / "data/examples").as_posix(),
19
- type=str
20
- )
21
- parser.add_argument(
22
- "--trained_model_dir",
23
- default=(project_path / "trained_models").as_posix(),
24
  type=str
25
  )
26
  args = parser.parse_args()
@@ -28,10 +39,10 @@ def get_args():
28
 
29
 
30
  def update_model_dropdown(language: str):
31
- if language not in model_map.keys():
32
  raise ValueError(f"Unsupported language: {language}")
33
 
34
- choices = model_map[language]
35
  choices = [c["repo_id"] for c in choices]
36
  return gr.Dropdown(
37
  choices=choices,
@@ -50,14 +61,109 @@ def build_html_output(s: str, style: str = "result_item_success"):
50
  """
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def process_uploaded_file(language: str,
54
  repo_id: str,
55
  decoding_method: str,
56
  num_active_paths: int,
57
  add_punctuation: str,
58
  in_filename: str,
 
59
  ):
60
- return "Dummy", build_html_output("Dummy")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  # css style is copied from
@@ -71,12 +177,22 @@ css = """
71
 
72
 
73
  def main():
 
 
 
 
 
 
 
 
 
 
74
  title = "# Automatic Speech Recognition with Next-gen Kaldi"
75
 
76
- language_choices = list(model_map.keys())
77
 
78
  language_to_models = defaultdict(list)
79
- for k, v in model_map.items():
80
  for m in v:
81
  repo_id = m["repo_id"]
82
  language_to_models[k].append(repo_id)
@@ -134,11 +250,11 @@ def main():
134
  uploaded_file,
135
  ],
136
  outputs=[uploaded_output, uploaded_html_info],
137
- fn=process_uploaded_file,
138
  )
139
 
140
  upload_button.click(
141
- process_uploaded_file,
142
  inputs=[
143
  language_radio,
144
  model_dropdown,
 
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
  from collections import defaultdict
5
+ from datetime import datetime
6
+ import functools
7
+ import io
8
+ import logging
9
+ from pathlib import Path
10
  import platform
11
+ import time
12
+
13
+ from project_settings import project_path, log_directory
14
+ import log
15
+
16
+ log.setup(log_directory=log_directory)
17
 
18
  import gradio as gr
19
+ import torch
20
+ import torchaudio
21
 
22
+ from toolbox.k2_sherpa.examples import examples
23
+ from toolbox.k2_sherpa import decode, models
24
+ from toolbox.k2_sherpa.utils import audio_convert
25
+
26
+ main_logger = logging.getLogger("main")
27
 
28
 
29
  def get_args():
30
  parser = argparse.ArgumentParser()
31
+
32
  parser.add_argument(
33
+ "--pretrained_model_dir",
34
+ default=(project_path / "pretrained_models").as_posix(),
 
 
 
 
 
35
  type=str
36
  )
37
  args = parser.parse_args()
 
39
 
40
 
41
  def update_model_dropdown(language: str):
42
+ if language not in models.model_map.keys():
43
  raise ValueError(f"Unsupported language: {language}")
44
 
45
+ choices = models.model_map[language]
46
  choices = [c["repo_id"] for c in choices]
47
  return gr.Dropdown(
48
  choices=choices,
 
61
  """
62
 
63
 
64
+ @torch.no_grad()
65
+ def process(
66
+ language: str,
67
+ repo_id: str,
68
+ decoding_method: str,
69
+ num_active_paths: int,
70
+ add_punctuation: str,
71
+ in_filename: str,
72
+ pretrained_model_dir: Path,
73
+ ):
74
+ main_logger.info("language: {}".format(language))
75
+ main_logger.info("repo_id: {}".format(repo_id))
76
+ main_logger.info("decoding_method: {}".format(decoding_method))
77
+ main_logger.info("num_active_paths: {}".format(num_active_paths))
78
+ main_logger.info("in_filename: {}".format(in_filename))
79
+
80
+ m_list = models.model_map.get(language)
81
+ if m_list is None:
82
+ raise AssertionError("language invalid: {}".format(language))
83
+
84
+ m_dict = None
85
+ for m in m_list:
86
+ if m["repo_id"] == repo_id:
87
+ m_dict = m
88
+ if m_dict is None:
89
+ raise AssertionError("repo_id invalid: {}".format(repo_id))
90
+
91
+ local_model_dir = pretrained_model_dir / "huggingface" / repo_id
92
+
93
+ out_filename = io.BytesIO()
94
+ audio_convert(in_filename, out_filename)
95
+
96
+ recognizer = models.load_recognizer(
97
+ repo_id=m_dict["repo_id"],
98
+ nn_model_file=m_dict["nn_model_file"],
99
+ tokens_file=m_dict["tokens_file"],
100
+ sub_folder=m_dict["sub_folder"],
101
+ local_model_dir=local_model_dir,
102
+ recognizer_type=m_dict["recognizer_type"],
103
+ decoding_method=decoding_method,
104
+ num_active_paths=num_active_paths,
105
+ )
106
+
107
+ now = datetime.now()
108
+ date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
109
+ logging.info(f"Started at {date_time}")
110
+ start = time.time()
111
+
112
+ text = decode.decode_by_recognizer(recognizer=recognizer,
113
+ filename=out_filename,
114
+ )
115
+
116
+ date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
117
+ end = time.time()
118
+
119
+ metadata = torchaudio.info(out_filename)
120
+ duration = metadata.num_frames / 16000
121
+ rtf = (end - start) / duration
122
+
123
+ main_logger.info(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s")
124
+
125
+ info = f"""
126
+ Wave duration : {duration: .3f} s <br/>
127
+ Processing time: {end - start: .3f} s <br/>
128
+ RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
129
+ """
130
+
131
+ main_logger.info(info)
132
+ main_logger.info(f"\nrepo_id: {repo_id}\nhyp: {text}")
133
+
134
+ return text, build_html_output(info)
135
+
136
+
137
  def process_uploaded_file(language: str,
138
  repo_id: str,
139
  decoding_method: str,
140
  num_active_paths: int,
141
  add_punctuation: str,
142
  in_filename: str,
143
+ pretrained_model_dir: Path,
144
  ):
145
+ if in_filename is None or in_filename == "":
146
+ return "", build_html_output(
147
+ "Please first upload a file and then click "
148
+ 'the button "submit for recognition"',
149
+ "result_item_error",
150
+ )
151
+ main_logger.info(f"Processing uploaded file: {in_filename}")
152
+
153
+ try:
154
+ return process(
155
+ in_filename=in_filename,
156
+ language=language,
157
+ repo_id=repo_id,
158
+ decoding_method=decoding_method,
159
+ num_active_paths=num_active_paths,
160
+ add_punctuation=add_punctuation,
161
+ pretrained_model_dir=pretrained_model_dir,
162
+ )
163
+ except Exception as e:
164
+ msg = "transcribe error: {}".format(str(e))
165
+ main_logger.info(msg)
166
+ return "", build_html_output(msg, "result_item_error")
167
 
168
 
169
  # css style is copied from
 
177
 
178
 
179
  def main():
180
+ args = get_args()
181
+
182
+ pretrained_model_dir = Path(args.pretrained_model_dir)
183
+ pretrained_model_dir.mkdir(exist_ok=True)
184
+
185
+ process_uploaded_file_ = functools.partial(
186
+ process_uploaded_file,
187
+ pretrained_model_dir=pretrained_model_dir,
188
+ )
189
+
190
  title = "# Automatic Speech Recognition with Next-gen Kaldi"
191
 
192
+ language_choices = list(models.model_map.keys())
193
 
194
  language_to_models = defaultdict(list)
195
+ for k, v in models.model_map.items():
196
  for m in v:
197
  repo_id = m["repo_id"]
198
  language_to_models[k].append(repo_id)
 
250
  uploaded_file,
251
  ],
252
  outputs=[uploaded_output, uploaded_html_info],
253
+ fn=process_uploaded_file_,
254
  )
255
 
256
  upload_button.click(
257
+ process_uploaded_file_,
258
  inputs=[
259
  language_radio,
260
  model_dropdown,
project_settings.py CHANGED
@@ -7,6 +7,9 @@ from pathlib import Path
7
  project_path = os.path.abspath(os.path.dirname(__file__))
8
  project_path = Path(project_path)
9
 
 
 
 
10
  temp_directory = project_path / "temp"
11
  temp_directory.mkdir(parents=True, exist_ok=True)
12
 
 
7
  project_path = os.path.abspath(os.path.dirname(__file__))
8
  project_path = Path(project_path)
9
 
10
+ log_directory = project_path / "log"
11
+ log_directory.mkdir(parents=True, exist_ok=True)
12
+
13
  temp_directory = project_path / "temp"
14
  temp_directory.mkdir(parents=True, exist_ok=True)
15
 
toolbox/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == "__main__":
5
+ pass
toolbox/k2_sherpa/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == "__main__":
5
+ pass
decode.py β†’ toolbox/k2_sherpa/decode.py RENAMED
File without changes
examples.py β†’ toolbox/k2_sherpa/examples.py RENAMED
File without changes
models.py β†’ toolbox/k2_sherpa/models.py RENAMED
@@ -24,35 +24,36 @@ model_map = {
24
  "Chinese": [
25
  {
26
  "repo_id": "csukuangfj/wenet-chinese-model",
27
- "model_file": "final.zip",
28
  "tokens_file": "units.txt",
29
  "subfolder": ".",
 
30
  }
31
  ]
32
  }
33
 
34
 
35
  def download_model(repo_id: str,
36
- nn_model_filename: str,
37
- tokens_filename: str,
38
  sub_folder: str,
39
  local_model_dir: str,
40
  ):
41
 
42
- nn_model_filename = huggingface_hub.hf_hub_download(
43
  repo_id=repo_id,
44
- filename=nn_model_filename,
45
  subfolder=sub_folder,
46
  local_dir=local_model_dir,
47
  )
48
 
49
- tokens_filename = huggingface_hub.hf_hub_download(
50
  repo_id=repo_id,
51
- filename=tokens_filename,
52
  subfolder=sub_folder,
53
  local_dir=local_model_dir,
54
  )
55
- return nn_model_filename, tokens_filename
56
 
57
 
58
  @lru_cache(maxsize=10)
@@ -82,25 +83,34 @@ def load_sherpa_offline_recognizer(nn_model_file: str,
82
  return recognizer
83
 
84
 
85
- def load_recognizer(
86
- repo_id: str,
87
- nn_model_filename: str,
88
- tokens_filename: str,
89
  sub_folder: str,
90
  local_model_dir: str,
91
- recognizer_type: EnumRecognizerType,
92
  decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
 
93
  ):
94
  if not os.path.exists(local_model_dir):
95
  download_model(
96
  repo_id=repo_id,
97
- nn_model_filename=nn_model_filename,
98
- tokens_filename=tokens_filename,
99
  sub_folder=sub_folder,
100
  local_model_dir=local_model_dir,
101
  )
102
 
103
- return
 
 
 
 
 
 
 
 
 
104
 
105
 
106
  if __name__ == "__main__":
 
24
  "Chinese": [
25
  {
26
  "repo_id": "csukuangfj/wenet-chinese-model",
27
+ "nn_model_file": "final.zip",
28
  "tokens_file": "units.txt",
29
  "subfolder": ".",
30
+ "recognizer_type": EnumRecognizerType.sherpa_offline_recognizer.value,
31
  }
32
  ]
33
  }
34
 
35
 
36
  def download_model(repo_id: str,
37
+ nn_model_file: str,
38
+ tokens_file: str,
39
  sub_folder: str,
40
  local_model_dir: str,
41
  ):
42
 
43
+ nn_model_file = huggingface_hub.hf_hub_download(
44
  repo_id=repo_id,
45
+ filename=nn_model_file,
46
  subfolder=sub_folder,
47
  local_dir=local_model_dir,
48
  )
49
 
50
+ tokens_file = huggingface_hub.hf_hub_download(
51
  repo_id=repo_id,
52
+ filename=tokens_file,
53
  subfolder=sub_folder,
54
  local_dir=local_model_dir,
55
  )
56
+ return nn_model_file, tokens_file
57
 
58
 
59
  @lru_cache(maxsize=10)
 
83
  return recognizer
84
 
85
 
86
+ def load_recognizer(repo_id: str,
87
+ nn_model_file: str,
88
+ tokens_file: str,
 
89
  sub_folder: str,
90
  local_model_dir: str,
91
+ recognizer_type: str,
92
  decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
93
+ num_active_paths: int = 4,
94
  ):
95
  if not os.path.exists(local_model_dir):
96
  download_model(
97
  repo_id=repo_id,
98
+ nn_model_file=nn_model_file,
99
+ tokens_file=tokens_file,
100
  sub_folder=sub_folder,
101
  local_model_dir=local_model_dir,
102
  )
103
 
104
+ if recognizer_type == EnumRecognizerType.sherpa_offline_recognizer.value:
105
+ recognizer = load_sherpa_offline_recognizer(
106
+ nn_model_file=nn_model_file,
107
+ tokens_file=tokens_file,
108
+ decoding_method=decoding_method,
109
+ num_active_paths=num_active_paths,
110
+ )
111
+ else:
112
+ raise NotImplementedError("recognizer_type not support: {}".format(recognizer_type.value))
113
+ return recognizer
114
 
115
 
116
  if __name__ == "__main__":
toolbox/k2_sherpa/utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import librosa
4
+ import numpy as np
5
+ from scipy.io import wavfile
6
+
7
+
8
+ def audio_convert(in_filename: str,
9
+ out_filename: str,
10
+ sample_rate: int = 16000):
11
+ signal, _ = librosa.load(in_filename, sr=sample_rate)
12
+ signal *= 32768.0
13
+ signal = np.array(signal, dtype=np.int16)
14
+
15
+ wavfile.write(
16
+ out_filename,
17
+ rate=sample_rate,
18
+ data=signal
19
+ )
20
+ return out_filename
21
+
22
+
23
+ if __name__ == "__main__":
24
+ pass