konieshadow committed on
Commit
48811fe
·
1 Parent(s): 406e6ac

更新ASR模块,重构转录逻辑,新增基于Transformers的distil-whisper模型支持,并优化音频处理流程。

Browse files
examples/simple_asr.py CHANGED
@@ -14,7 +14,7 @@ from pathlib import Path
14
  sys.path.insert(0, str(Path(__file__).parent.parent))
15
 
16
  from src.podcast_transcribe.audio import load_audio
17
- from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio
18
 
19
  logger = logging.getLogger("asr_example")
20
 
@@ -43,7 +43,7 @@ def main():
43
 
44
  # 进行转录
45
  print("开始转录...")
46
- result = transcribe_audio(audio, model_name=model, device=device)
47
 
48
  # 输出结果
49
  print("\n转录结果:")
 
14
  sys.path.insert(0, str(Path(__file__).parent.parent))
15
 
16
  from src.podcast_transcribe.audio import load_audio
17
+ from src.podcast_transcribe.asr.asr_router import transcribe_audio
18
 
19
  logger = logging.getLogger("asr_example")
20
 
 
43
 
44
  # 进行转录
45
  print("开始转录...")
46
+ result = transcribe_audio(audio, provider="distil_whisper_transformers", model_name=model, device=device)
47
 
48
  # 输出结果
49
  print("\n转录结果:")
src/podcast_transcribe/asr/asr_base.py CHANGED
@@ -46,6 +46,54 @@ class BaseTranscriber:
46
  """
47
  raise NotImplementedError("子类必须实现_load_model方法")
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
50
  """
51
  准备音频数据
 
46
  """
47
  raise NotImplementedError("子类必须实现_load_model方法")
48
 
49
+ def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
50
+ """
51
+ 转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
52
+
53
+ 参数:
54
+ audio: 要转录的AudioSegment对象
55
+
56
+ 返回:
57
+ TranscriptionResult对象,包含转录结果
58
+ """
59
+ logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频") # 移除了模型名称,因为基类不知道具体模型
60
+
61
+ # 直接处理整个音频,不进行分块
62
+ processed_audio = self._prepare_audio(audio)
63
+ samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
64
+
65
+ try:
66
+ model_result = self._perform_transcription(samples)
67
+ text = self._get_text_from_result(model_result)
68
+ segments = self._convert_segments(model_result)
69
+ language = self._detect_language(text)
70
+
71
+ logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
72
+ return TranscriptionResult(text=text, segments=segments, language=language)
73
+ except Exception as e:
74
+ logger.error(f"转录失败: {str(e)}", exc_info=True)
75
+ raise RuntimeError(f"转录失败: {str(e)}")
76
+
77
+ def _get_text_from_result(self, result):
78
+ """
79
+ 从结果中获取文本
80
+
81
+ 参数:
82
+ result: 模型的转录结果
83
+
84
+ 返回:
85
+ 转录的文本
86
+ """
87
+ return result.get("text", "")
88
+
89
+ def _perform_transcription(self, audio_data):
90
+ """执行转录的抽象方法,由子类实现"""
91
+ raise NotImplementedError("子类必须实现_perform_transcription方法")
92
+
93
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
94
+ """将模型结果转换为分段的抽象方法,由子类实现"""
95
+ raise NotImplementedError("子类必须实现_convert_segments方法")
96
+
97
  def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
98
  """
99
  准备音频数据
src/podcast_transcribe/asr/{asr_distil_whisper.py → asr_distil_whisper_transformers.py} RENAMED
@@ -1,5 +1,5 @@
1
  """
2
- 基于MLX或Transformers实现的语音识别模块,使用distil-whisper模型
3
  """
4
 
5
  import os
@@ -15,145 +15,7 @@ from .asr_base import BaseTranscriber, TranscriptionResult
15
  logger = logging.getLogger("asr")
16
 
17
 
18
- class DistilWhisperTranscriber(BaseTranscriber):
19
- """抽象基类:Distil Whisper转录器的共享实现"""
20
-
21
- def __init__(
22
- self,
23
- model_name: str,
24
- **kwargs
25
- ):
26
- """
27
- 初始化转录器
28
-
29
- 参数:
30
- model_name: 模型名称
31
- **kwargs: 其他参数
32
- """
33
- super().__init__(model_name=model_name, **kwargs)
34
-
35
- def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
36
- """
37
- 转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
38
-
39
- 参数:
40
- audio: 要转录的AudioSegment对象
41
- chunk_duration_s: 分块处理的块时长(秒)- 此参数被忽略
42
- overlap_s: 分块间的重叠时长(秒)- 此参数被忽略
43
-
44
- 返回:
45
- TranscriptionResult对象,包含转录结果
46
- """
47
- logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频(distil-whisper模型)")
48
-
49
- # 直接处理整个音频,不进行分块
50
- processed_audio = self._prepare_audio(audio)
51
- samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
52
-
53
- try:
54
- model_result = self._perform_transcription(samples)
55
- text = self._get_text_from_result(model_result)
56
- segments = self._convert_segments(model_result)
57
- language = self._detect_language(text)
58
-
59
- logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
60
- return TranscriptionResult(text=text, segments=segments, language=language)
61
- except Exception as e:
62
- logger.error(f"转录失败: {str(e)}", exc_info=True)
63
- raise RuntimeError(f"转录失败: {str(e)}")
64
-
65
- def _get_text_from_result(self, result):
66
- """
67
- 从结果中获取文本
68
-
69
- 参数:
70
- result: 模型的转录结果
71
-
72
- 返回:
73
- 转录的文本
74
- """
75
- return result.get("text", "")
76
-
77
- def _load_model(self):
78
- """加载模型的抽象方法,由子类实现"""
79
- raise NotImplementedError("子类必须实现_load_model方法")
80
-
81
- def _perform_transcription(self, audio_data):
82
- """执行转录的抽象方法,由子类实现"""
83
- raise NotImplementedError("子类必须实现_perform_transcription方法")
84
-
85
- def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
86
- """将模型结果转换为分段的抽象方法,由子类实现"""
87
- raise NotImplementedError("子类必须实现_convert_segments方法")
88
-
89
-
90
- class MLXDistilWhisperTranscriber(DistilWhisperTranscriber):
91
- """使用MLX加载和运行distil-whisper模型的转录器"""
92
-
93
- def __init__(
94
- self,
95
- model_name: str = "mlx-community/distil-whisper-large-v3",
96
- ):
97
- """
98
- 初始化转录器
99
-
100
- 参数:
101
- model_name: 模型名称
102
- """
103
- super().__init__(model_name=model_name)
104
-
105
- def _load_model(self):
106
- """加载Distil Whisper MLX模型"""
107
- try:
108
- # 懒加载mlx-whisper
109
- try:
110
- import mlx_whisper
111
- except ImportError:
112
- raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
113
-
114
- logger.info(f"开始加载模型 {self.model_name}")
115
- self.model = mlx_whisper.load_models.load_model(self.model_name)
116
- logger.info(f"模型加载成功")
117
- except Exception as e:
118
- logger.error(f"加载模型失败: {str(e)}", exc_info=True)
119
- raise RuntimeError(f"加载模型失败: {str(e)}")
120
-
121
- def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
122
- """
123
- 将模型的分段结果转换为所需格式
124
-
125
- 参数:
126
- result: 模型返回的结果
127
-
128
- 返回:
129
- 转换后的分段列表
130
- """
131
- segments = []
132
-
133
- for segment in result.get("segments", []):
134
- segments.append({
135
- "start": segment.get("start", 0.0),
136
- "end": segment.get("end", 0.0),
137
- "text": segment.get("text", "").strip()
138
- })
139
-
140
- return segments
141
-
142
- def _perform_transcription(self, audio_data):
143
- """
144
- 执行转录
145
-
146
- 参数:
147
- audio_data: 音频数据(numpy数组)
148
-
149
- 返回:
150
- 模型的转录结果
151
- """
152
- from mlx_whisper import transcribe
153
- return transcribe(audio_data, path_or_hf_repo=self.model_name)
154
-
155
-
156
- class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
157
  """使用Transformers加载和运行distil-whisper模型的转录器"""
158
 
159
  def __init__(
@@ -285,35 +147,27 @@ class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
285
  # 如果新格式失败,回退到简单调用
286
  return self.pipeline(audio_data)
287
 
288
-
289
  # 统一的接口函数
290
  def transcribe_audio(
291
  audio_segment: AudioSegment,
292
  model_name: str = None,
293
- backend: Literal["mlx", "transformers"] = "transformers",
294
  device: str = "cpu",
295
  ) -> TranscriptionResult:
296
  """
297
- 使用Distil Whisper模型转录音频
298
 
299
  参数:
300
  audio_segment: 输入的AudioSegment对象
301
  model_name: 使用的模型名称,如果不指定则使用默认模型
302
- backend: 后端类型,'mlx'或'transformers'
303
- device: 推理设备,仅对transformers后端有效
304
 
305
  返回:
306
  TranscriptionResult对象,包含转录的文本、分段和语言
307
  """
308
- logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒,后端: {backend}")
309
 
310
- if backend == "mlx":
311
- default_model = "mlx-community/distil-whisper-large-v3"
312
- model = model_name or default_model
313
- transcriber = MLXDistilWhisperTranscriber(model_name=model)
314
- else: # transformers
315
- default_model = "distil-whisper/distil-large-v3.5"
316
- model = model_name or default_model
317
- transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
318
 
319
- return transcriber.transcribe(audio_segment)
 
1
  """
2
+ 基于Transformers实现的语音识别模块,使用distil-whisper模型
3
  """
4
 
5
  import os
 
15
  logger = logging.getLogger("asr")
16
 
17
 
18
+ class TransformersDistilWhisperTranscriber(BaseTranscriber):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """使用Transformers加载和运行distil-whisper模型的转录器"""
20
 
21
  def __init__(
 
147
  # 如果新格式失败,回退到简单调用
148
  return self.pipeline(audio_data)
149
 
 
150
  # 统一的接口函数
151
def transcribe_audio(
    audio_segment: AudioSegment,
    model_name: str = None,
    device: str = "cpu",
) -> TranscriptionResult:
    """
    Transcribe audio with a Distil Whisper model (Transformers backend).

    Args:
        audio_segment: input AudioSegment object
        model_name: model to use; falls back to the default model when omitted
        device: inference device, 'cpu' or 'cuda'

    Returns:
        TranscriptionResult containing the transcribed text, segments and language
    """
    logger.info(f"调用 transcribe_audio 函数 (Transformers后端),音频长度: {len(audio_segment)/1000:.2f}秒,设备: {device}")

    default_model = "distil-whisper/distil-large-v3.5"
    model = model_name or default_model
    # A fresh transcriber is constructed per call; model loading (and any
    # caching of it) happens inside TransformersDistilWhisperTranscriber.
    transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)

    return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_router.py CHANGED
@@ -4,11 +4,10 @@ ASR模型调用路由器
4
  """
5
 
6
  import logging
7
- from typing import Dict, Any, Literal, Optional, Callable
8
  from pydub import AudioSegment
9
  import spaces
10
  from .asr_base import TranscriptionResult
11
- from . import asr_distil_whisper
12
 
13
  # 配置日志
14
  logger = logging.getLogger("asr")
@@ -25,8 +24,8 @@ class ASRRouter:
25
  # 定义支持的provider配置
26
  self._provider_configs = {
27
  "distil_whisper_transformers": {
28
- "module_path": "asr_distil_whisper",
29
- "function_name": "transcribe_audio",
30
  "default_model": "distil-whisper/distil-large-v3.5",
31
  "supported_params": ["model_name", "device"],
32
  "description": "基于Transformers的Distil Whisper模型"
@@ -50,8 +49,9 @@ class ASRRouter:
50
  module_path = self._provider_configs[provider]["module_path"]
51
  logger.info(f"获取模块: {module_path}")
52
 
53
- # 所有provider现在都指向同一个模块
54
- module = asr_distil_whisper
 
55
 
56
  self._loaded_modules[provider] = module
57
  logger.info(f"模块 {module_path} 获取成功")
@@ -98,6 +98,10 @@ class ASRRouter:
98
  if "model_name" not in filtered_params and "model_name" in supported_params:
99
  filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
100
 
 
 
 
 
101
  return filtered_params
102
 
103
  def transcribe(
@@ -179,17 +183,17 @@ def transcribe_audio(
179
  provider: str = "distil_whisper_transformers",
180
  model_name: Optional[str] = None,
181
  device: str = "cpu",
182
- backend: str = "transformers",
183
  **kwargs
184
  ) -> TranscriptionResult:
 
 
 
185
  # 准备参数
186
  params = kwargs.copy()
187
  if model_name is not None:
188
  params["model_name"] = model_name
189
- if device != "cpu":
190
  params["device"] = device
191
- if backend is not None:
192
- params["backend"] = backend
193
 
194
  return _router.transcribe(audio_segment, provider, **params)
195
 
 
4
  """
5
 
6
  import logging
7
+ from typing import Dict, Any, Optional, Callable
8
  from pydub import AudioSegment
9
  import spaces
10
  from .asr_base import TranscriptionResult
 
11
 
12
  # 配置日志
13
  logger = logging.getLogger("asr")
 
24
  # 定义支持的provider配置
25
  self._provider_configs = {
26
  "distil_whisper_transformers": {
27
+ "module_path": ".asr_distil_whisper_transformers",
28
+ "function_name": "transcribe_audio",
29
  "default_model": "distil-whisper/distil-large-v3.5",
30
  "supported_params": ["model_name", "device"],
31
  "description": "基于Transformers的Distil Whisper模型"
 
49
  module_path = self._provider_configs[provider]["module_path"]
50
  logger.info(f"获取模块: {module_path}")
51
 
52
+ # 使用 importlib 动态导入模块
53
+ import importlib
54
+ module = importlib.import_module(module_path, package=__package__)
55
 
56
  self._loaded_modules[provider] = module
57
  logger.info(f"模块 {module_path} 获取成功")
 
98
  if "model_name" not in filtered_params and "model_name" in supported_params:
99
  filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
100
 
101
+ # 对于 Transformers backend,如果 device 未指定,则默认为 cpu
102
+ if provider == "distil_whisper_transformers" and "device" in supported_params and "device" not in filtered_params:
103
+ filtered_params["device"] = "cpu"
104
+
105
  return filtered_params
106
 
107
  def transcribe(
 
183
  provider: str = "distil_whisper_transformers",
184
  model_name: Optional[str] = None,
185
  device: str = "cpu",
 
186
  **kwargs
187
  ) -> TranscriptionResult:
188
+ """
189
+ 统一的音频转录接口,通过路由器选择后端
190
+ """
191
  # 准备参数
192
  params = kwargs.copy()
193
  if model_name is not None:
194
  params["model_name"] = model_name
195
+ if device != "cpu": # 只有当 device 不是默认值才传递,或者根据需要传递所有支持的参数
196
  params["device"] = device
 
 
197
 
198
  return _router.transcribe(audio_segment, provider, **params)
199